1use darling::{FromAttributes, FromMeta};
2use proc_macro::{Span, TokenStream};
3use proc_macro2::{Delimiter, Group, Punct};
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6 parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, DataEnum, DataStruct,
7 DeriveInput, Expr, Field, Fields, Generics, Ident, Lifetime, LifetimeParam,
8};
9
10#[derive(Debug)]
11struct DataTag([u8; 4]);
12
13impl FromMeta for DataTag {
14 fn from_string(value: &str) -> darling::Result<Self> {
15 assert!(value.len() <= 4, "Tag cannot be longer than 4 bytes");
16 assert!(!value.is_empty(), "Tag cannot be empty");
17
18 let mut out = [0u8; 4];
19
20 let input = value.as_bytes();
21 let len = input.len().min(4);
23 out[0..len].copy_from_slice(input);
24
25 Ok(Self(out))
26 }
27}
28
29impl ToTokens for DataTag {
30 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
31 tokens.append(Punct::new('&', proc_macro2::Spacing::Joint));
32 let [a, b, c, d] = &self.0;
33 let inner_stream = quote!(#a, #b, #c, #d);
34 tokens.append(Group::new(Delimiter::Bracket, inner_stream));
35 }
36}
37
38#[derive(Debug, FromAttributes)]
39#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
40struct TdfFieldAttrs {
41 tag: Option<DataTag>,
42 #[darling(default)]
43 into: Option<Expr>,
44 #[darling(default)]
45 skip: bool,
46}
47
/// Struct-level `#[tdf(...)]` options.
#[derive(Debug, FromAttributes)]
#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
struct TdfStructAttr {
    /// Treat the struct as a TDF group: serialization ends with a group end
    /// marker and `TdfTyped::TYPE` is `TdfType::Group`.
    #[darling(default)]
    group: bool,
    /// For group structs, write a leading `2` byte before the fields.
    #[darling(default)]
    prefix_two: bool,
}
56
/// Variant-level `#[tdf(...)]` options for field-less `#[repr]` enums.
#[derive(Debug, FromAttributes)]
#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
struct TdfEnumVariantAttr {
    /// Marks the variant returned when a decoded value matches no other variant.
    #[darling(default)]
    default: bool,
}
63
64#[derive(Debug, FromAttributes)]
65#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
66struct TdfTaggedEnumVariantAttr {
67 pub key: Option<Expr>,
68
69 #[darling(default)]
70 pub tag: Option<DataTag>,
71
72 #[darling(default)]
73 pub prefix_two: bool,
74
75 #[darling(default)]
76 pub default: bool,
77
78 #[darling(default)]
79 pub unset: bool,
80}
81
82#[proc_macro_derive(TdfSerialize, attributes(tdf))]
83pub fn derive_tdf_serialize(input: TokenStream) -> TokenStream {
84 let input: DeriveInput = parse_macro_input!(input);
85
86 match &input.data {
87 syn::Data::Struct(data) => impl_serialize_struct(&input, data),
88 syn::Data::Enum(data) => {
89 if is_enum_tagged(data) {
90 impl_serialize_tagged_enum(&input, data)
91 } else {
92 impl_serialize_repr_enum(&input, data)
93 }
94 }
95 syn::Data::Union(_) => panic!("TdfSerialize cannot be implemented on union types"),
96 }
97}
98
99#[proc_macro_derive(TdfTyped, attributes(tdf))]
100pub fn derive_tdf_typed(input: TokenStream) -> TokenStream {
101 let input: DeriveInput = parse_macro_input!(input);
102
103 match &input.data {
104 syn::Data::Struct(data) => impl_type_struct(&input, data),
105 syn::Data::Enum(data) => {
106 if is_enum_tagged(data) {
107 impl_type_tagged_enum(&input, data)
108 } else {
109 impl_type_repr_enum(&input, data)
110 }
111 }
112 syn::Data::Union(_) => panic!("TdfTyped cannot be implemented on union types"),
113 }
114}
115
116#[proc_macro_derive(TdfDeserialize, attributes(tdf))]
117pub fn derive_tdf_deserialize(input: TokenStream) -> TokenStream {
118 let input: DeriveInput = parse_macro_input!(input);
119 match &input.data {
120 syn::Data::Struct(data) => impl_deserialize_struct(&input, data),
121 syn::Data::Enum(data) => {
122 if is_enum_tagged(data) {
123 impl_deserialize_tagged_enum(&input, data)
124 } else {
125 impl_deserialize_repr_enum(&input, data)
126 }
127 }
128
129 syn::Data::Union(_) => panic!("TdfDeserialize cannot be implemented on union types"),
130 }
131}
132
133fn get_repr_attribute(attrs: &[Attribute]) -> Option<Ident> {
134 attrs
135 .iter()
136 .filter_map(|attr| attr.meta.require_list().ok())
137 .find(|value| value.path.is_ident("repr"))
138 .map(|attr| {
139 let value: Ident = attr.parse_args().expect("Failed to parse repr type");
140 value
141 })
142}
143
144fn is_enum_tagged(data: &DataEnum) -> bool {
148 data.variants
149 .iter()
150 .any(|variant| !variant.fields.is_empty())
151}
152
153fn impl_type_struct(input: &DeriveInput, _data: &DataStruct) -> TokenStream {
154 let attr =
155 TdfStructAttr::from_attributes(&input.attrs).expect("Failed to parse tdf struct attrs");
156
157 assert!(
158 attr.group,
159 "Cannot derive TdfTyped on non group struct, type is unknown"
160 );
161
162 let ident = &input.ident;
163 let generics = &input.generics;
164 let where_clause = generics.where_clause.as_ref();
165
166 quote! {
167 impl #generics tdf::TdfTyped for #ident #generics #where_clause {
168 const TYPE: tdf::TdfType = tdf::TdfType::Group;
169 }
170 }
171 .into()
172}
173
174fn impl_type_repr_enum(input: &DeriveInput, _data: &DataEnum) -> TokenStream {
175 let ident = &input.ident;
176 let repr = get_repr_attribute(&input.attrs)
177 .expect("Non-tagged enums require #[repr({ty})] to be specified");
178
179 quote! {
180 impl tdf::TdfTyped for #ident {
181 const TYPE: tdf::TdfType = <#repr as tdf::TdfTyped>::TYPE;
182 }
183 }
184 .into()
185}
186
187fn impl_type_tagged_enum(input: &DeriveInput, _data: &DataEnum) -> TokenStream {
188 let ident = &input.ident;
189
190 let generics = &input.generics;
191 let where_clause = generics.where_clause.as_ref();
192
193 quote! {
194 impl #generics tdf::TdfTyped for #ident #generics #where_clause {
195 const TYPE: tdf::TdfType = tdf::TdfType::TaggedUnion;
196 }
197 }
198 .into()
199}
200
201fn tag_field_serialize(
202 field: &Field,
203 into: Option<Expr>,
204 tag: Option<DataTag>,
205 is_struct: bool,
206) -> proc_macro2::TokenStream {
207 let tag = tag.expect("Fields that arent skipped must specify a tag");
208 let ident = &field.ident;
209 let ty = &field.ty;
210
211 let value = if is_struct {
214 quote!(&self.#ident)
215 } else {
216 quote!(#ident)
217 };
218
219 if let Some(into) = into {
220 quote!( w.tag_owned::<#into>(#tag, <#ty as Into::<#into>>::into(*#value)); )
221 } else {
222 quote! ( w.tag_ref::<#ty>(#tag, #value); )
223 }
224}
225
226fn impl_serialize_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
227 let attr =
228 TdfStructAttr::from_attributes(&input.attrs).expect("Failed to parse tdf struct attrs");
229 let ident = &input.ident;
230 let generics = &input.generics;
231 let where_clause = generics.where_clause.as_ref();
232
233 let serialize_impls = data.fields.iter().filter_map(|field| {
234 let attributes =
235 TdfFieldAttrs::from_attributes(&field.attrs).expect("Failed to parse tdf field attrs");
236 if attributes.skip {
237 None
238 } else {
239 Some(tag_field_serialize(
240 field,
241 attributes.into,
242 attributes.tag,
243 true,
244 ))
245 }
246 });
247
248 let mut leading = None;
249 let mut trailing = None;
250
251 if attr.group {
252 if attr.prefix_two {
253 leading = Some(quote! { w.write_byte(2); });
254 }
255
256 trailing = Some(quote!( w.tag_group_end();));
257 }
258
259 quote! {
260 impl #generics tdf::TdfSerialize for #ident #generics #where_clause {
261 fn serialize<S: tdf::TdfSerializer>(&self, w: &mut S) {
262 #leading
263 #(#serialize_impls)*
264 #trailing
265 }
266 }
267 }
268 .into()
269}
270
271fn impl_serialize_repr_enum(input: &DeriveInput, _data: &DataEnum) -> TokenStream {
272 let ident = &input.ident;
273 let repr = get_repr_attribute(&input.attrs)
274 .expect("Non-tagged enums require #[repr({ty})] to be specified");
275
276 quote! {
277 impl tdf::TdfSerializeOwned for #ident {
278 fn serialize_owned<S: tdf::TdfSerializer>(self, w: &mut S) {
279 <#repr as tdf::TdfSerializeOwned>::serialize_owned(self as #repr, w);
280 }
281 }
282
283 impl tdf::TdfSerialize for #ident {
284 #[inline]
285 fn serialize<S: tdf::TdfSerializer>(&self, w: &mut S) {
286 tdf::TdfSerializeOwned::serialize_owned(*self, w);
287 }
288 }
289 }
290 .into()
291}
292
293fn impl_serialize_tagged_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream {
294 let ident = &input.ident;
295
296 let field_impls: Vec<_> = data
297 .variants
298 .iter()
299 .map(|variant| {
300 let attr: TdfTaggedEnumVariantAttr =
301 TdfTaggedEnumVariantAttr::from_attributes(&variant.attrs)
302 .expect("Failed to parse tdf field attrs");
303
304 (variant, attr)
305 })
306 .map(|(variant, attr)| {
307 let var_ident = &variant.ident;
308 let value_tag = attr.tag;
309 let is_unit = attr.unset || attr.default;
310
311 if let Fields::Unit = &variant.fields {
314 assert!(
315 is_unit,
316 "Only unset or default enum variants can have no content"
317 );
318
319 return quote! {
320 Self::#var_ident => {
321 w.write_byte(tdf::types::tagged_union::TAGGED_UNSET_KEY);
322 }
323 };
324 }
325
326 assert!(
327 !is_unit,
328 "Enum variants with fields cannot be used as the default or unset variant"
329 );
330
331 let discriminant = attr.key.expect("Missing discriminant key");
332 let value_tag = value_tag.expect("Missing value tag");
333
334 match &variant.fields {
335 Fields::Named(fields) => {
337 let (idents, impls): (Vec<_>, Vec<_>) = fields
338 .named
339 .iter()
340 .filter_map(|field| {
341 let attributes = TdfFieldAttrs::from_attributes(&field.attrs)
342 .expect("Failed to parse tdf field attrs");
343 if attributes.skip {
344 return None;
345 }
346
347 Some((field, attributes))
348 })
349 .map(|(field, attributes)| {
350 let ident = field.ident.as_ref().expect("Field missing ident");
351 let serialize = tag_field_serialize(field, attributes.into,attributes.tag, false);
352 (ident, serialize)
353 })
354 .unzip();
355
356 let field_names: proc_macro2::TokenStream = if idents.is_empty() {
358 quote!(..)
359 } else if idents.len() != fields.named.len() {
360 quote!(#(#idents,)* ..)
361 } else {
362 quote!(#(#idents),*)
363 };
364
365 let mut leading = None;
366
367 if attr.prefix_two {
368 leading = Some(quote!( w.write_byte(2); ))
369 }
370
371 quote! {
372 Self::#var_ident { #field_names } => {
373 w.write_byte(#discriminant);
374 tdf::Tagged::serialize_raw(w, #value_tag, tdf::TdfType::Group);
375
376 #leading
377 #(#impls)*
378 w.tag_group_end();
379 }
380 }
381 }
382
383 Fields::Unnamed(fields) => {
385 let fields = &fields.unnamed;
386 let field = fields.first().expect("Unnamed tagged enum missing field");
387
388 assert!(
389 fields.len() == 1,
390 "Tagged union cannot have more than one unnamed field"
391 );
392
393 let field_ty = &field.ty;
394
395 quote! {
396 Self::#var_ident(value) => {
397 w.write_byte(#discriminant);
398 tdf::Tagged::serialize_raw(w, #value_tag, <#field_ty as tdf::TdfTyped>::TYPE);
399
400 <#field_ty as tdf::TdfSerialize>::serialize(value, w);
401 }
402 }
403 }
404 Fields::Unit => unreachable!("Unit types should already be handled above"),
405 }
406 })
407 .collect();
408 let generics = &input.generics;
409 let where_clause = generics.where_clause.as_ref();
410
411 quote! {
412 impl #generics tdf::TdfSerialize for #ident #generics #where_clause {
413 fn serialize<S: tdf::TdfSerializer>(&self, w: &mut S) {
414 match self {
415 #(#field_impls),*
416 }
417 }
418 }
419 }
420 .into()
421}
422
423fn get_deserialize_lifetime(generics: &Generics) -> LifetimeParam {
430 let mut lifetimes = generics.lifetimes();
431
432 let lifetime = lifetimes
433 .next()
434 .cloned()
435 .unwrap_or_else(|| LifetimeParam::new(Lifetime::new("'_", Span::call_site().into())));
437
438 assert!(
439 lifetimes.next().is_none(),
440 "Deserializable structs cannot have more than one lifetime"
441 );
442
443 lifetime
444}
445
446fn tag_field_deserialize(field: &Field) -> proc_macro2::TokenStream {
449 let attributes =
450 TdfFieldAttrs::from_attributes(&field.attrs).expect("Failed to parse tdf field attrs");
451
452 let ident = &field.ident;
453 let ty = &field.ty;
454
455 if attributes.skip {
456 quote!( let #ident = Default::default(); )
457 } else {
458 let tag = attributes
459 .tag
460 .expect("Fields that arent skipped must specify a tag");
461
462 if let Some(into) = attributes.into {
464 quote!( let #ident = <#ty as From<#into>>::from(r.tag::<#into>(#tag)?); )
465 } else {
466 quote!( let #ident = r.tag::<#ty>(#tag)?; )
467 }
468 }
469}
470
471fn impl_deserialize_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
472 let attributes =
473 TdfStructAttr::from_attributes(&input.attrs).expect("Failed to parse tdf struct attrs");
474
475 let ident = &input.ident;
476
477 let generics = &input.generics;
478 let lifetime = get_deserialize_lifetime(generics);
479 let where_clause = generics.where_clause.as_ref();
480
481 let idents = data.fields.iter().filter_map(|field| field.ident.as_ref());
482 let impls = data.fields.iter().map(tag_field_deserialize);
483
484 let mut trailing = None;
485
486 if attributes.group {
489 trailing = Some(quote!( tdf::GroupSlice::deserialize_content_skip(r)?; ));
490 }
491
492 quote! {
493 impl #generics tdf::TdfDeserialize<#lifetime> for #ident #generics #where_clause {
494 fn deserialize(r: &mut tdf::TdfDeserializer<#lifetime>) -> tdf::DecodeResult<Self> {
495 #(#impls)*
496 #trailing
497 Ok(Self {
498 #(#idents),*
499 })
500 }
501 }
502 }
503 .into()
504}
505
506fn impl_deserialize_repr_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream {
507 let repr = get_repr_attribute(&input.attrs)
508 .expect("Non-tagged enums require #[repr({ty})] to be specified");
509
510 let mut default = None;
511
512 let variant_cases: Vec<_> = data
513 .variants
514 .iter()
515 .map(|variant| {
516 let attr = TdfEnumVariantAttr::from_attributes(&variant.attrs)
517 .expect("Failed to parse tdf enum variant attrs");
518 (variant, attr)
519 })
520 .filter(|(variant, attr)| {
521 if !attr.default {
522 return true;
523 }
524
525 assert!(
526 default.is_none(),
527 "Cannot have more than one default variant"
528 );
529
530 let ident = &variant.ident;
531
532 default = Some(quote!(_ => Self::#ident));
533
534 false
535 })
536 .map(|(variant, _attr)| {
537 let var_ident = &variant.ident;
538 let (_, discriminant) = variant
539 .discriminant
540 .as_ref()
541 .expect("Repr enum variants must include a descriminant for each value");
542
543 quote! ( #discriminant => Self::#var_ident )
544 })
545 .collect();
546
547 let ident = &input.ident;
548 let default = default.unwrap_or_else(
549 || quote!(_ => return Err(tdf::DecodeError::Other("Missing fallback enum variant"))),
550 );
551
552 quote! {
553 impl tdf::TdfDeserialize<'_> for #ident {
554 fn deserialize(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult<Self> {
555 let value = <#repr as tdf::TdfDeserialize<'_>>::deserialize(r)?;
556 Ok(match value {
557 #(#variant_cases,)*
558 #default
559 })
560 }
561 }
562 }
563 .into()
564}
565
566fn impl_deserialize_tagged_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream {
567 let generics = &input.generics;
568 let lifetime = get_deserialize_lifetime(generics);
569 let where_clause = generics.where_clause.as_ref();
570
571 let mut has_unset = false;
572 let mut has_default = false;
573
574 let mut impls: Punctuated<proc_macro2::TokenStream, Comma> = data
575 .variants
576 .iter()
577 .map(|variant| {
578 let attr: TdfTaggedEnumVariantAttr =
579 TdfTaggedEnumVariantAttr::from_attributes(&variant.attrs)
580 .expect("Failed to parse tdf field attrs");
581
582 let var_ident = &variant.ident;
583 let is_unit = attr.unset || attr.default;
584
585 if let Fields::Unit = &variant.fields {
586 assert!(
587 is_unit,
588 "Only unset or default enum variants can have no content"
589 );
590
591 assert!(
592 !(attr.default && attr.unset),
593 "Enum variant cannot be default and unset"
594 );
595
596 return if attr.default {
597 assert!(!has_default, "Default variant already defined");
598 has_default = true;
599
600 quote! {
601 _ => {
602 let tag = tdf::Tagged::deserialize_owned(r)?;
603 tag.ty.skip(r, false)?;
604 Self::#var_ident
605 }
606 }
607 } else {
608 assert!(!has_unset, "Unset variant already defined");
609 has_unset = true;
610 quote!( tdf::types::tagged_union::TAGGED_UNSET_KEY => Self::#var_ident )
611 };
612 }
613
614 assert!(
615 !is_unit,
616 "Enum variants with fields cannot be used as the default or unset variant"
617 );
618
619 let discriminant = attr.key.expect("Missing discriminant key");
620 let _value_tag = attr.tag.expect("Missing value tag");
621
622 match &variant.fields {
623 Fields::Named(fields) => {
625 let (idents, impls): (Vec<_>, Vec<_>) = fields
626 .named
627 .iter()
628 .map(|field| {
629 let ident = field.ident.as_ref().unwrap();
630 let value = tag_field_deserialize(field);
631 (ident, value)
632 })
633 .unzip();
634
635 quote! {
636 #discriminant => {
637 let tag = tdf::Tagged::deserialize_owned(r)?;
638
639 #(#impls)*
640 tdf::GroupSlice::deserialize_content_skip(r)?;
641
642 Self::#var_ident {
643 #(#idents),*
644 }
645 }
646 }
647 }
648 Fields::Unnamed(fields) => {
650 let fields = &fields.unnamed;
651 let field = fields.first().expect("Unnamed tagged enum missing field");
652
653 assert!(
654 fields.len() == 1,
655 "Tagged union cannot have more than one unnamed field"
656 );
657
658 let field_ty = &field.ty;
659
660 quote! {
661 #discriminant => {
662 let tag = tdf::Tagged::deserialize_owned(r)?;
663
664 let value = <#field_ty as tdf::TdfDeserialize<'_>>::deserialize(r)?;
665 Self::#var_ident(value)
666 }
667 }
668 }
669
670 Fields::Unit => unreachable!("Unit types should already be handled above"),
671 }
672 })
673 .collect();
674
675 if !has_unset {
676 impls.push(quote!(
678 tdf::types::tagged_union::TAGGED_UNSET_KEY => return Err(tdf::DecodeError::Other("Missing unset enum variant"))
679 ));
680 }
681
682 if !has_default {
683 impls.push(quote!(
685 _ => return Err(tdf::DecodeError::Other("Missing default enum variant"))
686 ));
687 }
688
689 let ident = &input.ident;
690
691 quote! {
692 impl #generics tdf::TdfDeserialize<#lifetime> for #ident #generics #where_clause {
693 fn deserialize(r: &mut tdf::TdfDeserializer<#lifetime>) -> tdf::DecodeResult<Self> {
694 let discriminant = <u8 as tdf::TdfDeserialize<#lifetime>>::deserialize(r)?;
695
696 Ok(match discriminant {
697 #impls
698 })
699 }
700 }
701 }
702 .into()
703}