tdf_derive/
lib.rs

1use darling::{FromAttributes, FromMeta};
2use proc_macro::{Span, TokenStream};
3use proc_macro2::{Delimiter, Group, Punct};
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6    parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, DataEnum, DataStruct,
7    DeriveInput, Expr, Field, Fields, Generics, Ident, Lifetime, LifetimeParam,
8};
9
/// A TDF field tag given as a string in `#[tdf(tag = "...")]`, stored as up to
/// four raw bytes. Tags shorter than four bytes are padded with trailing zeros.
#[derive(Debug)]
struct DataTag([u8; 4]);
12
13impl FromMeta for DataTag {
14    fn from_string(value: &str) -> darling::Result<Self> {
15        assert!(value.len() <= 4, "Tag cannot be longer than 4 bytes");
16        assert!(!value.is_empty(), "Tag cannot be empty");
17
18        let mut out = [0u8; 4];
19
20        let input = value.as_bytes();
21        // Only copy the max of 4 bytes
22        let len = input.len().min(4);
23        out[0..len].copy_from_slice(input);
24
25        Ok(Self(out))
26    }
27}
28
29impl ToTokens for DataTag {
30    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
31        tokens.append(Punct::new('&', proc_macro2::Spacing::Joint));
32        let [a, b, c, d] = &self.0;
33        let inner_stream = quote!(#a, #b, #c, #d);
34        tokens.append(Group::new(Delimiter::Bracket, inner_stream));
35    }
36}
37
38#[derive(Debug, FromAttributes)]
39#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
40struct TdfFieldAttrs {
41    tag: Option<DataTag>,
42    #[darling(default)]
43    into: Option<Expr>,
44    #[darling(default)]
45    skip: bool,
46}
47
/// Container-level `#[tdf(...)]` options for structs.
#[derive(Debug, FromAttributes)]
#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
struct TdfStructAttr {
    /// `#[tdf(group)]`: the struct is a TDF group. Group structs report
    /// `TdfType::Group` from `TdfTyped`, write a group end marker after their
    /// fields when serialized, and skip any remaining group content when
    /// deserialized. `TdfTyped` can only be derived for group structs.
    #[darling(default)]
    group: bool,
    /// `#[tdf(prefix_two)]`: when serializing a group, write a leading `2`
    /// byte before the fields. Has no effect unless `group` is also set.
    #[darling(default)]
    prefix_two: bool,
}
56
/// Variant-level `#[tdf(...)]` options for repr (fieldless) enums.
#[derive(Debug, FromAttributes)]
#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
struct TdfEnumVariantAttr {
    /// `#[tdf(default)]`: this variant is produced when deserializing a value
    /// that matches no other variant. At most one variant may be the default.
    #[darling(default)]
    default: bool,
}
63
64#[derive(Debug, FromAttributes)]
65#[darling(attributes(tdf), forward_attrs(allow, doc, cfg))]
66struct TdfTaggedEnumVariantAttr {
67    pub key: Option<Expr>,
68
69    #[darling(default)]
70    pub tag: Option<DataTag>,
71
72    #[darling(default)]
73    pub prefix_two: bool,
74
75    #[darling(default)]
76    pub default: bool,
77
78    #[darling(default)]
79    pub unset: bool,
80}
81
82#[proc_macro_derive(TdfSerialize, attributes(tdf))]
83pub fn derive_tdf_serialize(input: TokenStream) -> TokenStream {
84    let input: DeriveInput = parse_macro_input!(input);
85
86    match &input.data {
87        syn::Data::Struct(data) => impl_serialize_struct(&input, data),
88        syn::Data::Enum(data) => {
89            if is_enum_tagged(data) {
90                impl_serialize_tagged_enum(&input, data)
91            } else {
92                impl_serialize_repr_enum(&input, data)
93            }
94        }
95        syn::Data::Union(_) => panic!("TdfSerialize cannot be implemented on union types"),
96    }
97}
98
99#[proc_macro_derive(TdfTyped, attributes(tdf))]
100pub fn derive_tdf_typed(input: TokenStream) -> TokenStream {
101    let input: DeriveInput = parse_macro_input!(input);
102
103    match &input.data {
104        syn::Data::Struct(data) => impl_type_struct(&input, data),
105        syn::Data::Enum(data) => {
106            if is_enum_tagged(data) {
107                impl_type_tagged_enum(&input, data)
108            } else {
109                impl_type_repr_enum(&input, data)
110            }
111        }
112        syn::Data::Union(_) => panic!("TdfTyped cannot be implemented on union types"),
113    }
114}
115
116#[proc_macro_derive(TdfDeserialize, attributes(tdf))]
117pub fn derive_tdf_deserialize(input: TokenStream) -> TokenStream {
118    let input: DeriveInput = parse_macro_input!(input);
119    match &input.data {
120        syn::Data::Struct(data) => impl_deserialize_struct(&input, data),
121        syn::Data::Enum(data) => {
122            if is_enum_tagged(data) {
123                impl_deserialize_tagged_enum(&input, data)
124            } else {
125                impl_deserialize_repr_enum(&input, data)
126            }
127        }
128
129        syn::Data::Union(_) => panic!("TdfDeserialize cannot be implemented on union types"),
130    }
131}
132
133fn get_repr_attribute(attrs: &[Attribute]) -> Option<Ident> {
134    attrs
135        .iter()
136        .filter_map(|attr| attr.meta.require_list().ok())
137        .find(|value| value.path.is_ident("repr"))
138        .map(|attr| {
139            let value: Ident = attr.parse_args().expect("Failed to parse repr type");
140            value
141        })
142}
143
144/// Determines whether an enum should be considered to be a Tagged Union
145/// rather than a repr enum. Any enum types that have fields cannot be
146/// repr types and thus must be Tagged Union's
147fn is_enum_tagged(data: &DataEnum) -> bool {
148    data.variants
149        .iter()
150        .any(|variant| !variant.fields.is_empty())
151}
152
153fn impl_type_struct(input: &DeriveInput, _data: &DataStruct) -> TokenStream {
154    let attr =
155        TdfStructAttr::from_attributes(&input.attrs).expect("Failed to parse tdf struct attrs");
156
157    assert!(
158        attr.group,
159        "Cannot derive TdfTyped on non group struct, type is unknown"
160    );
161
162    let ident = &input.ident;
163    let generics = &input.generics;
164    let where_clause = generics.where_clause.as_ref();
165
166    quote! {
167        impl #generics tdf::TdfTyped for #ident #generics #where_clause {
168            const TYPE: tdf::TdfType = tdf::TdfType::Group;
169        }
170    }
171    .into()
172}
173
174fn impl_type_repr_enum(input: &DeriveInput, _data: &DataEnum) -> TokenStream {
175    let ident = &input.ident;
176    let repr = get_repr_attribute(&input.attrs)
177        .expect("Non-tagged enums require #[repr({ty})] to be specified");
178
179    quote! {
180        impl tdf::TdfTyped for #ident {
181            const TYPE: tdf::TdfType = <#repr as tdf::TdfTyped>::TYPE;
182        }
183    }
184    .into()
185}
186
187fn impl_type_tagged_enum(input: &DeriveInput, _data: &DataEnum) -> TokenStream {
188    let ident = &input.ident;
189
190    let generics = &input.generics;
191    let where_clause = generics.where_clause.as_ref();
192
193    quote! {
194        impl #generics tdf::TdfTyped for #ident #generics #where_clause {
195            const TYPE: tdf::TdfType = tdf::TdfType::TaggedUnion;
196        }
197    }
198    .into()
199}
200
201fn tag_field_serialize(
202    field: &Field,
203    into: Option<Expr>,
204    tag: Option<DataTag>,
205    is_struct: bool,
206) -> proc_macro2::TokenStream {
207    let tag = tag.expect("Fields that arent skipped must specify a tag");
208    let ident = &field.ident;
209    let ty = &field.ty;
210
211    // TODO: Validate tag
212
213    let value = if is_struct {
214        quote!(&self.#ident)
215    } else {
216        quote!(#ident)
217    };
218
219    if let Some(into) = into {
220        quote!( w.tag_owned::<#into>(#tag, <#ty as Into::<#into>>::into(*#value)); )
221    } else {
222        quote! ( w.tag_ref::<#ty>(#tag, #value); )
223    }
224}
225
226fn impl_serialize_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
227    let attr =
228        TdfStructAttr::from_attributes(&input.attrs).expect("Failed to parse tdf struct attrs");
229    let ident = &input.ident;
230    let generics = &input.generics;
231    let where_clause = generics.where_clause.as_ref();
232
233    let serialize_impls = data.fields.iter().filter_map(|field| {
234        let attributes =
235            TdfFieldAttrs::from_attributes(&field.attrs).expect("Failed to parse tdf field attrs");
236        if attributes.skip {
237            None
238        } else {
239            Some(tag_field_serialize(
240                field,
241                attributes.into,
242                attributes.tag,
243                true,
244            ))
245        }
246    });
247
248    let mut leading = None;
249    let mut trailing = None;
250
251    if attr.group {
252        if attr.prefix_two {
253            leading = Some(quote! { w.write_byte(2); });
254        }
255
256        trailing = Some(quote!( w.tag_group_end();));
257    }
258
259    quote! {
260        impl #generics tdf::TdfSerialize for #ident #generics #where_clause {
261            fn serialize<S: tdf::TdfSerializer>(&self, w: &mut S) {
262                #leading
263                #(#serialize_impls)*
264                #trailing
265            }
266        }
267    }
268    .into()
269}
270
271fn impl_serialize_repr_enum(input: &DeriveInput, _data: &DataEnum) -> TokenStream {
272    let ident = &input.ident;
273    let repr = get_repr_attribute(&input.attrs)
274        .expect("Non-tagged enums require #[repr({ty})] to be specified");
275
276    quote! {
277        impl tdf::TdfSerializeOwned for #ident {
278            fn serialize_owned<S: tdf::TdfSerializer>(self, w: &mut S) {
279                <#repr as tdf::TdfSerializeOwned>::serialize_owned(self as #repr, w);
280            }
281        }
282
283        impl tdf::TdfSerialize for #ident {
284            #[inline]
285            fn serialize<S: tdf::TdfSerializer>(&self, w: &mut S) {
286               tdf::TdfSerializeOwned::serialize_owned(*self, w);
287            }
288        }
289    }
290    .into()
291}
292
293fn impl_serialize_tagged_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream {
294    let ident = &input.ident;
295
296    let field_impls: Vec<_> = data
297        .variants
298        .iter()
299        .map(|variant| {
300            let attr: TdfTaggedEnumVariantAttr =
301                TdfTaggedEnumVariantAttr::from_attributes(&variant.attrs)
302                    .expect("Failed to parse tdf field attrs");
303
304            (variant, attr)
305        })
306        .map(|(variant, attr)| {
307            let var_ident = &variant.ident;
308            let value_tag = attr.tag;
309            let is_unit = attr.unset || attr.default;
310
311            // TODO: Ensure no duplicates & validate value tag matches
312
313            if let Fields::Unit = &variant.fields {
314                assert!(
315                    is_unit,
316                    "Only unset or default enum variants can have no content"
317                );
318
319                return quote! {
320                    Self::#var_ident => {
321                        w.write_byte(tdf::types::tagged_union::TAGGED_UNSET_KEY);
322                    }
323                };
324            }
325
326            assert!(
327                !is_unit,
328                "Enum variants with fields cannot be used as the default or unset variant"
329            );
330
331            let discriminant = attr.key.expect("Missing discriminant key");
332            let value_tag = value_tag.expect("Missing value tag");
333
334            match &variant.fields {
335                // Variants with named fields are handled as groups
336                Fields::Named(fields) => {
337                    let (idents, impls): (Vec<_>, Vec<_>) = fields
338                        .named
339                        .iter()
340                        .filter_map(|field| {
341                            let attributes = TdfFieldAttrs::from_attributes(&field.attrs)
342                                .expect("Failed to parse tdf field attrs");
343                            if attributes.skip {
344                                return None;
345                            }
346
347                            Some((field, attributes))
348                        })
349                        .map(|(field, attributes)| {
350                            let ident = field.ident.as_ref().expect("Field missing ident");
351                            let serialize = tag_field_serialize(field, attributes.into,attributes.tag, false);
352                            (ident, serialize)
353                        })
354                        .unzip();
355
356                    // Handle how field names are listed
357                    let field_names: proc_macro2::TokenStream = if idents.is_empty() {
358                        quote!(..)
359                    } else if idents.len() != fields.named.len() {
360                        quote!(#(#idents,)* ..)
361                    } else {
362                        quote!(#(#idents),*)
363                    };
364
365                    let mut leading = None;
366
367                    if attr.prefix_two {
368                        leading = Some(quote!( w.write_byte(2); ))
369                    }
370
371                    quote! {
372                        Self::#var_ident { #field_names } => {
373                            w.write_byte(#discriminant);
374                            tdf::Tagged::serialize_raw(w, #value_tag, tdf::TdfType::Group);
375
376                            #leading
377                            #(#impls)*
378                            w.tag_group_end();
379                        }
380                    }
381                }
382
383                // Variants with unnamed fields are treated as the type of the first field (Only one field is allowed)
384                Fields::Unnamed(fields) => {
385                    let fields = &fields.unnamed;
386                    let field = fields.first().expect("Unnamed tagged enum missing field");
387
388                    assert!(
389                        fields.len() == 1,
390                        "Tagged union cannot have more than one unnamed field"
391                    );
392
393                    let field_ty = &field.ty;
394
395                    quote! {
396                        Self::#var_ident(value) => {
397                            w.write_byte(#discriminant);
398                            tdf::Tagged::serialize_raw(w, #value_tag, <#field_ty as tdf::TdfTyped>::TYPE);
399
400                            <#field_ty as tdf::TdfSerialize>::serialize(value, w);
401                        }
402                    }
403                }
404                Fields::Unit => unreachable!("Unit types should already be handled above"),
405            }
406        })
407        .collect();
408    let generics = &input.generics;
409    let where_clause = generics.where_clause.as_ref();
410
411    quote! {
412        impl #generics tdf::TdfSerialize for #ident #generics #where_clause {
413            fn serialize<S: tdf::TdfSerializer>(&self, w: &mut S) {
414                match self {
415                    #(#field_impls),*
416                }
417            }
418        }
419    }
420    .into()
421}
422
423/// Obtains the lifetime that should be used by [tdf::TdfDeserializer] when
424/// deserializing values. If the generic parameters specify a lifetime then
425/// that lifetime is used otherwise the default lifetime '_ is used instead
426///
427/// Will panic if structure uses more than 1 lifetime as its not possible
428/// to deserialize with more than one lifetime
429fn get_deserialize_lifetime(generics: &Generics) -> LifetimeParam {
430    let mut lifetimes = generics.lifetimes();
431
432    let lifetime = lifetimes
433        .next()
434        .cloned()
435        // Use a default '_ lifetime while deserializing when no lifetime is provided
436        .unwrap_or_else(|| LifetimeParam::new(Lifetime::new("'_", Span::call_site().into())));
437
438    assert!(
439        lifetimes.next().is_none(),
440        "Deserializable structs cannot have more than one lifetime"
441    );
442
443    lifetime
444}
445
446/// Creates a token stream for deserializing the provided `field`
447/// loads the tag and whether
448fn tag_field_deserialize(field: &Field) -> proc_macro2::TokenStream {
449    let attributes =
450        TdfFieldAttrs::from_attributes(&field.attrs).expect("Failed to parse tdf field attrs");
451
452    let ident = &field.ident;
453    let ty = &field.ty;
454
455    if attributes.skip {
456        quote!( let #ident = Default::default(); )
457    } else {
458        let tag = attributes
459            .tag
460            .expect("Fields that arent skipped must specify a tag");
461
462        // TODO: Validate tag
463        if let Some(into) = attributes.into {
464            quote!( let #ident = <#ty as From<#into>>::from(r.tag::<#into>(#tag)?); )
465        } else {
466            quote!( let #ident = r.tag::<#ty>(#tag)?; )
467        }
468    }
469}
470
471fn impl_deserialize_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
472    let attributes =
473        TdfStructAttr::from_attributes(&input.attrs).expect("Failed to parse tdf struct attrs");
474
475    let ident = &input.ident;
476
477    let generics = &input.generics;
478    let lifetime = get_deserialize_lifetime(generics);
479    let where_clause = generics.where_clause.as_ref();
480
481    let idents = data.fields.iter().filter_map(|field| field.ident.as_ref());
482    let impls = data.fields.iter().map(tag_field_deserialize);
483
484    let mut trailing = None;
485
486    // Groups need leading deserialization for possible prefixes and trailing
487    // deserialization to read any unused tags and to read the group end byte
488    if attributes.group {
489        trailing = Some(quote!( tdf::GroupSlice::deserialize_content_skip(r)?; ));
490    }
491
492    quote! {
493        impl #generics tdf::TdfDeserialize<#lifetime> for #ident #generics #where_clause {
494            fn deserialize(r: &mut tdf::TdfDeserializer<#lifetime>) -> tdf::DecodeResult<Self> {
495                #(#impls)*
496                #trailing
497                Ok(Self {
498                    #(#idents),*
499                })
500            }
501        }
502    }
503    .into()
504}
505
506fn impl_deserialize_repr_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream {
507    let repr = get_repr_attribute(&input.attrs)
508        .expect("Non-tagged enums require #[repr({ty})] to be specified");
509
510    let mut default = None;
511
512    let variant_cases: Vec<_> = data
513        .variants
514        .iter()
515        .map(|variant| {
516            let attr = TdfEnumVariantAttr::from_attributes(&variant.attrs)
517                .expect("Failed to parse tdf enum variant attrs");
518            (variant, attr)
519        })
520        .filter(|(variant, attr)| {
521            if !attr.default {
522                return true;
523            }
524
525            assert!(
526                default.is_none(),
527                "Cannot have more than one default variant"
528            );
529
530            let ident = &variant.ident;
531
532            default = Some(quote!(_ => Self::#ident));
533
534            false
535        })
536        .map(|(variant, _attr)| {
537            let var_ident = &variant.ident;
538            let (_, discriminant) = variant
539                .discriminant
540                .as_ref()
541                .expect("Repr enum variants must include a descriminant for each value");
542
543            quote! ( #discriminant => Self::#var_ident )
544        })
545        .collect();
546
547    let ident = &input.ident;
548    let default = default.unwrap_or_else(
549        || quote!(_ => return Err(tdf::DecodeError::Other("Missing fallback enum variant"))),
550    );
551
552    quote! {
553        impl tdf::TdfDeserialize<'_> for #ident {
554            fn deserialize(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult<Self> {
555                let value = <#repr as tdf::TdfDeserialize<'_>>::deserialize(r)?;
556                Ok(match value {
557                    #(#variant_cases,)*
558                    #default
559                })
560            }
561        }
562    }
563    .into()
564}
565
566fn impl_deserialize_tagged_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream {
567    let generics = &input.generics;
568    let lifetime = get_deserialize_lifetime(generics);
569    let where_clause = generics.where_clause.as_ref();
570
571    let mut has_unset = false;
572    let mut has_default = false;
573
574    let mut impls: Punctuated<proc_macro2::TokenStream, Comma> = data
575        .variants
576        .iter()
577        .map(|variant| {
578            let attr: TdfTaggedEnumVariantAttr =
579                TdfTaggedEnumVariantAttr::from_attributes(&variant.attrs)
580                    .expect("Failed to parse tdf field attrs");
581
582            let var_ident = &variant.ident;
583            let is_unit = attr.unset || attr.default;
584
585            if let Fields::Unit = &variant.fields {
586                assert!(
587                    is_unit,
588                    "Only unset or default enum variants can have no content"
589                );
590
591                assert!(
592                    !(attr.default && attr.unset),
593                    "Enum variant cannot be default and unset"
594                );
595
596                return if attr.default {
597                    assert!(!has_default, "Default variant already defined");
598                    has_default = true;
599
600                    quote! {
601                        _  => {
602                            let tag = tdf::Tagged::deserialize_owned(r)?;
603                            tag.ty.skip(r, false)?;
604                            Self::#var_ident
605                        }
606                    }
607                } else {
608                    assert!(!has_unset, "Unset variant already defined");
609                    has_unset = true;
610                    quote!( tdf::types::tagged_union::TAGGED_UNSET_KEY => Self::#var_ident )
611                };
612            }
613
614            assert!(
615                !is_unit,
616                "Enum variants with fields cannot be used as the default or unset variant"
617            );
618
619            let discriminant = attr.key.expect("Missing discriminant key");
620            let _value_tag = attr.tag.expect("Missing value tag");
621
622            match &variant.fields {
623                // Variants with named fields are handled as groups
624                Fields::Named(fields) => {
625                    let (idents, impls): (Vec<_>, Vec<_>) = fields
626                        .named
627                        .iter()
628                        .map(|field| {
629                            let ident = field.ident.as_ref().unwrap();
630                            let value = tag_field_deserialize(field);
631                            (ident, value)
632                        })
633                        .unzip();
634
635                    quote! {
636                        #discriminant => {
637                            let tag = tdf::Tagged::deserialize_owned(r)?;
638
639                            #(#impls)*
640                            tdf::GroupSlice::deserialize_content_skip(r)?;
641
642                            Self::#var_ident {
643                                #(#idents),*
644                            }
645                        }
646                    }
647                }
648                // Variants with unnamed fields are treated as the type of the first field (Only one field is allowed)
649                Fields::Unnamed(fields) => {
650                    let fields = &fields.unnamed;
651                    let field = fields.first().expect("Unnamed tagged enum missing field");
652
653                    assert!(
654                        fields.len() == 1,
655                        "Tagged union cannot have more than one unnamed field"
656                    );
657
658                    let field_ty = &field.ty;
659
660                    quote! {
661                        #discriminant => {
662                            let tag = tdf::Tagged::deserialize_owned(r)?;
663
664                            let value = <#field_ty as tdf::TdfDeserialize<'_>>::deserialize(r)?;
665                            Self::#var_ident(value)
666                        }
667                    }
668                }
669
670                Fields::Unit => unreachable!("Unit types should already be handled above"),
671            }
672        })
673        .collect();
674
675    if !has_unset {
676        // If an unset variant is not specified its handling is replaced with a runtime error
677        impls.push(quote!(
678            tdf::types::tagged_union::TAGGED_UNSET_KEY => return Err(tdf::DecodeError::Other("Missing unset enum variant"))
679        ));
680    }
681
682    if !has_default {
683        // If a default variant is not specified its handling is replaced with a runtime error
684        impls.push(quote!(
685            _ => return Err(tdf::DecodeError::Other("Missing default enum variant"))
686        ));
687    }
688
689    let ident = &input.ident;
690
691    quote! {
692        impl #generics tdf::TdfDeserialize<#lifetime> for #ident #generics #where_clause {
693            fn deserialize(r: &mut tdf::TdfDeserializer<#lifetime>) -> tdf::DecodeResult<Self> {
694                let discriminant = <u8 as tdf::TdfDeserialize<#lifetime>>::deserialize(r)?;
695
696                Ok(match discriminant {
697                    #impls
698                })
699            }
700        }
701    }
702    .into()
703}