bilrost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/bilrost-derive/0.1014.1")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4#![no_std]
5
6//! This crate contains the derive macro implementations for the
7//! [`bilrost`][bilrost] crate; see the documentation in that crate for usage and
8//! details.
9//!
10//! [bilrost]: https://docs.rs/bilrost
11
12extern crate alloc;
13
14use crate::attrs::{bilrost_attrs, set_bool, set_option, tag_list_attr, word_attr, TagList};
15use crate::field::traits::{
16    DecodeLifetime::{Borrowed, Owned},
17    DecodeMode::{Distinguished, Relaxed},
18    FieldBearer, SinglyTagged, Tagged,
19    WhereFor::{self, Decode, Encode},
20};
21use crate::field::{
22    parse_message_fields, tag_measurer, Field, FieldTarget, InitMode, MessageFieldsSorted,
23    OneofVariant,
24};
25use alloc::collections::BTreeMap;
26use alloc::string::ToString;
27use alloc::vec;
28use alloc::vec::Vec;
29use eyre::{bail, eyre as err, Report as Error};
30use itertools::Itertools;
31use proc_macro2::TokenStream;
32use quote::{quote, ToTokens};
33use syn::{
34    parse2, Attribute, Data, DeriveInput, Expr, Fields, Generics, Ident, Meta, Pat, TypeGenerics,
35    Variant, WhereClause,
36};
37
38mod attrs;
39mod field;
40
41fn crate_name() -> TokenStream {
42    quote!(::bilrost)
43}
44
45/// Defines the common aliases for encoder types available to every bilrost derive.
46///
47/// The standard encoders are all made available in scope with lower-cased names, making them
48/// simultaneously easier to spell when writing the field attributes and making them less likely to
49/// shadow custom encoder types.
50fn encoder_alias_header() -> TokenStream {
51    let crate_ = crate_name();
52    quote! {
53        use #crate_::encoding::{
54            Fixed as fixed,
55            General as general,
56            GeneralPacked as general_packed,
57            Map as map,
58            Packed as packed,
59            PlainBytes as plainbytes,
60            Unpacked as unpacked,
61            Varint as varint,
62        };
63    }
64}
65
66/// Combines an optional WhereClause and any number of additional provided where term(s) into a
67/// phrase that will always be a valid where clause if present.
68fn append_wheres(
69    where_clause: Option<&WhereClause>,
70    wheres: impl IntoIterator<Item = TokenStream>,
71) -> Option<TokenStream> {
72    // dedup the where clauses by their String values
73    let where_terms: BTreeMap<_, _> = wheres
74        .into_iter()
75        .chain(
76            where_clause
77                .into_iter()
78                .flat_map(|w| w.predicates.iter().map(|term| quote!(#term))),
79        )
80        .map(|where_| (where_.to_string(), where_))
81        .collect();
82    // append our encoder where terms to the existing where clause if there is one
83    if where_terms.is_empty() {
84        None
85    } else {
86        let each_where_term = where_terms.values();
87        Some(quote! { where #(#each_where_term,)* })
88    }
89}
90
91/// Combines an optional where clause with additional terms for each field's encoder to assert that
92/// it supports the field's type.
93fn append_wheres_with_fields(
94    where_clause: Option<&WhereClause>,
95    wheres: impl IntoIterator<Item = TokenStream>,
96    fields: impl FieldBearer,
97    field_purpose: WhereFor,
98) -> Option<TokenStream> {
99    append_wheres(
100        where_clause,
101        wheres.into_iter().chain(fields.where_terms(field_purpose)),
102    )
103}
104
105/// Adds the given identifier to the generics list
106fn prepend_to_generics(generics: &Generics, ident: TokenStream) -> TokenStream {
107    let params = &generics.params;
108    quote!(<#ident, #params>)
109}
110
111fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
112    let crate_ = crate_name();
113    let input: DeriveInput = parse2(input)?;
114
115    let DeriveInput {
116        ident,
117        attrs: input_attrs,
118        generics: impl_generics,
119        data: Data::Struct(data_struct),
120        ..
121    } = input
122    else {
123        // `enum` types are only derived as `Message` in terms of their `Oneof` implementation
124        if matches!(input.data, Data::Enum(..)) {
125            return try_message_via_oneof(input);
126        } else {
127            bail!("Message can only be derived for a struct or an enum");
128        }
129    };
130
131    // Process attributes
132    let mut reserved_tags: Option<TagList> = None;
133    let mut distinguished = false;
134    let mut borrow_only = false;
135    let mut default_per_field = false;
136    let mut unknown_attrs = Vec::new();
137    for attr in bilrost_attrs(&input_attrs)? {
138        if let Some(tags) = tag_list_attr(&attr, "reserved_tags", None)? {
139            set_option(
140                &mut reserved_tags,
141                tags,
142                "duplicate reserved_tags attributes",
143            )?;
144        } else if word_attr(&attr, "distinguished") {
145            set_bool(&mut distinguished, "duplicated distinguished attributes")?;
146        } else if word_attr(&attr, "borrowed_only") {
147            set_bool(&mut borrow_only, "duplicated borrowed_only attributes")?;
148        } else if word_attr(&attr, "default_per_field") {
149            set_bool(
150                &mut default_per_field,
151                "duplicated default_per_field attributes",
152            )?;
153        } else {
154            unknown_attrs.push(attr);
155        }
156    }
157
158    if !unknown_attrs.is_empty() {
159        bail!(
160            "unknown attribute(s) for message: {attrs}",
161            attrs = quote!(#(#unknown_attrs),*),
162        )
163    }
164
165    let init_mode = match default_per_field {
166        true => InitMode::DefaultPerField,
167        false => InitMode::ParentDefault,
168    };
169
170    // Parse field data
171    let (ignored_fields, unsorted_fields): (Vec<_>, Vec<_>) =
172        parse_message_fields(data_struct.fields, init_mode, reserved_tags)?
173            .into_iter()
174            .partition(Field::is_ignored);
175
176    if distinguished && !ignored_fields.is_empty() {
177        bail!("messages with ignored fields cannot be distinguished");
178    }
179
180    let (_, ty_generics, where_clause) = impl_generics.split_for_impl();
181
182    let self_where = if default_per_field || ignored_fields.is_empty() {
183        None
184    } else {
185        // When there are ignored fields that we are taking from <Self as Default>, the whole
186        // message impl should be bounded by Self: Default
187        Some(quote!(Self: ::core::default::Default))
188    };
189
190    let borrow_generics = prepend_to_generics(&impl_generics, quote!('__a));
191
192    let where_fields = vec![unsorted_fields.as_slice(), ignored_fields.as_slice()];
193    let encoder_where_clause =
194        append_wheres_with_fields(where_clause, self_where.clone(), &where_fields, Encode);
195    let [owned_decoder_where_clause, borrowed_decoder_where_clause] =
196        [Owned, Borrowed].map(|lifetime| {
197            append_wheres_with_fields(
198                where_clause,
199                self_where.clone(),
200                &where_fields,
201                Decode(lifetime, Relaxed),
202            )
203        });
204
205    let self_instance = FieldTarget::MessageInstance(quote!(self));
206    let fields = MessageFieldsSorted::new(&unsorted_fields);
207    let encoded_len = fields.encoded_len(&self_instance);
208    let encode = fields.encode(&self_instance);
209    let prepend = fields.prepend(&self_instance);
210
211    let [decode_owned, decode_borrowed] = [Owned, Borrowed].map(|lifetime| {
212        let ident_str = ident.to_string();
213        let self_instance = self_instance.clone();
214        unsorted_fields.iter().map(move |field| {
215            let decode = field.decode(&self_instance, lifetime, Relaxed);
216            let tags = field.tags().into_iter().map(|tag| quote!(#tag));
217            let tags = Itertools::intersperse(tags, quote!(|));
218            let field_ident_str = field.ident().to_string();
219
220            quote! {
221                #(#tags)* => {
222                    if let ::core::result::Result::Err(mut error) = #decode {
223                        error.push(#ident_str, #field_ident_str);
224                        return ::core::result::Result::Err(error);
225                    }
226                }
227            }
228        })
229    });
230
231    let methods = unsorted_fields
232        .iter()
233        .flat_map(|field| field.methods())
234        .collect::<Vec<_>>();
235    let methods = if methods.is_empty() {
236        None
237    } else {
238        Some(quote! {
239            #[allow(dead_code)]
240            impl #impl_generics __Self #ty_generics #encoder_where_clause {
241                #(#methods)*
242            }
243        })
244    };
245
246    let static_guards = unsorted_fields
247        .iter()
248        .filter_map(|field| field.tag_list_guard());
249
250    let empties: Vec<_> = unsorted_fields
251        .iter()
252        .chain(ignored_fields.iter())
253        .flat_map(|field| {
254            let empty = field.empty()?;
255            let ident = field.ident();
256            Some(quote!(#ident: #empty))
257        })
258        .collect();
259    let is_empties: Vec<_> = unsorted_fields
260        .iter()
261        .map(|field| field.is_empty(&self_instance))
262        .collect();
263    let clears: Vec<_> = unsorted_fields
264        .iter()
265        .map(|field| field.clear(&self_instance))
266        .collect();
267
268    let maybe_fill_default = if default_per_field || ignored_fields.is_empty() {
269        None
270    } else {
271        // initialize ignored fields from <Self as Default>
272        Some(quote!(..::core::default::Default::default()))
273    };
274
275    let impl_owned_decoder = (!borrow_only).then(|| {
276        quote! {
277            impl #impl_generics #crate_::encoding::RawMessageDecoder
278            for __Self #ty_generics #owned_decoder_where_clause {
279                #[allow(unused_variables)]
280                #[inline]
281                fn raw_decode_field<__B>(
282                    &mut self,
283                    tag: u32,
284                    wire_type: #crate_::encoding::WireType,
285                    duplicated: bool,
286                    buf: #crate_::encoding::Capped<__B>,
287                    ctx: #crate_::encoding::DecodeContext,
288                ) -> ::core::result::Result<(), #crate_::DecodeError>
289                where
290                    __B: #crate_::bytes::Buf + ?Sized,
291                {
292                    let _ = <Self as #crate_::encoding::RawMessage>::__ASSERTIONS;
293                    match tag {
294                        #(#decode_owned)*
295                        _ => #crate_::encoding::skip_field(wire_type, buf)?,
296                    }
297                    ::core::result::Result::Ok(())
298                }
299            }
300        }
301    });
302
303    // The static guards should be instantiated within each of the methods of the trait; in newer
304    // versions of rust simply instantiating a variable in any method with `let` is enough to cause
305    // the assertions to be evaluated, but in older versions the evaluation might not happen unless
306    // there is an actual code path that invokes the function.
307    //
308    // Even in rust 1.79 nightly, if the constant is never named anywhere the assertions won't
309    // actually run.
310    let impls = quote! {
311        impl #impl_generics #crate_::encoding::RawMessage
312        for __Self #ty_generics #encoder_where_clause {
313            const __ASSERTIONS: () = { #(#static_guards)* };
314
315            fn empty() -> Self {
316                Self {
317                    #(#empties,)*
318                    #maybe_fill_default
319                }
320            }
321
322            fn is_empty(&self) -> bool {
323                true #(&& #is_empties)*
324            }
325
326            fn clear(&mut self) {
327                #(#clears)*
328            }
329
330            #[allow(unused_variables)]
331            fn raw_encode<__B>(&self, buf: &mut __B)
332            where
333                __B: #crate_::bytes::BufMut + ?Sized,
334            {
335                let _ = <Self as #crate_::encoding::RawMessage>::__ASSERTIONS;
336                #encode
337            }
338
339            #[allow(unused_variables)]
340            fn raw_prepend<__B>(&self, buf: &mut __B)
341            where
342                __B: #crate_::buf::ReverseBuf + ?Sized,
343            {
344                let _ = <Self as #crate_::encoding::RawMessage>::__ASSERTIONS;
345                #prepend
346            }
347
348            #[inline]
349            fn raw_encoded_len(&self) -> usize {
350                let _ = <Self as #crate_::encoding::RawMessage>::__ASSERTIONS;
351                #encoded_len
352            }
353        }
354
355        #impl_owned_decoder
356
357        impl #borrow_generics #crate_::encoding::RawMessageBorrowDecoder<'__a>
358        for __Self #ty_generics #borrowed_decoder_where_clause {
359            #[allow(unused_variables)]
360            #[inline]
361            fn raw_borrow_decode_field(
362                &mut self,
363                tag: u32,
364                wire_type: #crate_::encoding::WireType,
365                duplicated: bool,
366                buf: #crate_::encoding::Capped<&'__a [u8]>,
367                ctx: #crate_::encoding::DecodeContext,
368            ) -> ::core::result::Result<(), #crate_::DecodeError> {
369                let _ = <Self as #crate_::encoding::RawMessage>::__ASSERTIONS;
370                match tag {
371                    #(#decode_borrowed)*
372                    _ => #crate_::encoding::skip_field(wire_type, buf)?,
373                }
374                ::core::result::Result::Ok(())
375            }
376        }
377
378        impl #impl_generics #crate_::encoding::ForOverwrite<(), __Self #ty_generics> for ()
379        #encoder_where_clause {
380            fn for_overwrite() -> __Self #ty_generics {
381                <__Self #ty_generics as #crate_::encoding::RawMessage>::empty()
382            }
383        }
384
385        impl #impl_generics #crate_::encoding::EmptyState<(), __Self #ty_generics> for ()
386        #encoder_where_clause {
387            fn is_empty(val: &__Self #ty_generics) -> bool {
388                <__Self #ty_generics as #crate_::encoding::RawMessage>::is_empty(val)
389            }
390
391            fn clear(val: &mut __Self #ty_generics) {
392                <__Self #ty_generics as #crate_::encoding::RawMessage>::clear(val);
393            }
394        }
395    };
396
397    let distinguished_impls = distinguished.then(|| {
398        // At time of commenting distinguished mode precludes any additional self-bounds, as there
399        // cannot be ignored fields. If we add any in the future, we want to catch that
400        // automatically.
401        let distinguished_self_where = [quote!(Self: ::core::cmp::Eq)]
402            .into_iter()
403            .chain(self_where);
404        let [owned_decoder_where_clause, borrowed_decoder_where_clause] =
405            [Owned, Borrowed].map(|lifetime| {
406                append_wheres_with_fields(
407                    where_clause,
408                    distinguished_self_where.clone(),
409                    &where_fields,
410                    Decode(lifetime, Distinguished),
411                )
412            });
413
414        let [decode_owned, decode_borrowed] = [Owned, Borrowed].map(|lifetime| {
415            let ident_str = ident.to_string();
416            let self_instance = self_instance.clone();
417            unsorted_fields.iter().map(move |field| {
418                let decode = field.decode(&self_instance, lifetime, Distinguished);
419                let tags = field.tags().into_iter().map(|tag| quote!(#tag));
420                let tags = Itertools::intersperse(tags, quote!(|));
421                let field_ident_str = field.ident().to_string();
422
423                quote! {
424                    #(#tags)* => {
425                        match #decode {
426                            ::core::result::Result::Ok(new_canon) => {
427                                canon.update(new_canon);
428                            }
429                            ::core::result::Result::Err(mut error) => {
430                                error.push(#ident_str, #field_ident_str);
431                                return ::core::result::Result::Err(error);
432                            }
433                        }
434                    }
435                }
436            })
437        });
438
439        let impl_owned_decoder = (!borrow_only).then(|| {
440            quote! {
441                impl #impl_generics #crate_::encoding::RawDistinguishedMessageDecoder
442                for __Self #ty_generics #owned_decoder_where_clause {
443                    #[allow(unused_variables)]
444                    #[inline]
445                    fn raw_decode_field_distinguished<__B>(
446                        &mut self,
447                        tag: u32,
448                        wire_type: #crate_::encoding::WireType,
449                        duplicated: bool,
450                        buf: #crate_::encoding::Capped<__B>,
451                        ctx: #crate_::encoding::RestrictedDecodeContext,
452                    ) -> ::core::result::Result<#crate_::Canonicity, #crate_::DecodeError>
453                    where
454                        __B: #crate_::bytes::Buf + ?Sized,
455                    {
456                        let mut canon = #crate_::Canonicity::Canonical;
457                        match tag {
458                            #(#decode_owned)*
459                            _ => {
460                                canon.update(ctx.check(#crate_::Canonicity::HasExtensions)?);
461                                #crate_::encoding::skip_field(wire_type, buf)?;
462                            }
463                        }
464                        ::core::result::Result::Ok(canon)
465                    }
466                }
467            }
468        });
469
470        quote! {
471            #impl_owned_decoder
472
473            impl #borrow_generics #crate_::encoding::RawDistinguishedMessageBorrowDecoder<'__a>
474            for __Self #ty_generics #borrowed_decoder_where_clause {
475                #[allow(unused_variables)]
476                #[inline]
477                fn raw_borrow_decode_field_distinguished(
478                    &mut self,
479                    tag: u32,
480                    wire_type: #crate_::encoding::WireType,
481                    duplicated: bool,
482                    buf: #crate_::encoding::Capped<&'__a [u8]>,
483                    ctx: #crate_::encoding::RestrictedDecodeContext,
484                ) -> ::core::result::Result<#crate_::Canonicity, #crate_::DecodeError> {
485                    let canon = &mut #crate_::Canonicity::Canonical;
486                    match tag {
487                        #(#decode_borrowed)*
488                        _ => {
489                            canon.update(ctx.check(#crate_::Canonicity::HasExtensions)?);
490                            #crate_::encoding::skip_field(wire_type, buf)?;
491                        }
492                    }
493                    ::core::result::Result::Ok(*canon)
494                }
495            }
496        }
497    });
498
499    let aliases = encoder_alias_header();
500    let expanded = quote! {
501        const _: () = {
502            use #ident as __Self;
503
504            const _: () = {
505                #aliases
506
507                #impls
508
509                #distinguished_impls
510
511                #methods
512            };
513        };
514    };
515
516    Ok(expanded)
517}
518
519fn try_message_via_oneof(input: DeriveInput) -> Result<TokenStream, Error> {
520    let crate_ = crate_name();
521    let PreprocessedOneof {
522        ident,
523        impl_generics,
524        ty_generics,
525        where_clause,
526        variants,
527        distinguished,
528        borrow_only,
529        empty_variant,
530    } = preprocess_oneof(&input)?;
531
532    let tag_measurer_ty = tag_measurer(&variants);
533
534    if empty_variant.is_none() {
535        bail!("Message can only be derived for Oneof enums that have an empty variant.")
536    }
537
538    let borrow_generics = prepend_to_generics(impl_generics, quote!('__a));
539
540    let encoder_where_clause = append_wheres(
541        where_clause,
542        [quote!(#ident #ty_generics: #crate_::encoding::Oneof)],
543    );
544    let owned_decoder_where_clause = append_wheres(
545        where_clause,
546        [quote!(#ident #ty_generics: #crate_::encoding::OneofDecoder)],
547    );
548    let borrowed_decoder_where_clause = append_wheres(
549        where_clause,
550        [quote!(#ident #ty_generics: #crate_::encoding::OneofBorrowDecoder<'__a>)],
551    );
552
553    let impl_owned_decoder = (!borrow_only).then(|| {
554        quote! {
555            impl #impl_generics #crate_::encoding::RawMessageDecoder
556            for #ident #ty_generics #owned_decoder_where_clause {
557                #[inline(always)]
558                fn raw_decode_field<__B>(
559                    &mut self,
560                    tag: u32,
561                    wire_type: #crate_::encoding::WireType,
562                    _duplicated: bool,
563                    buf: #crate_::encoding::Capped<__B>,
564                    ctx: #crate_::encoding::DecodeContext,
565                ) -> ::core::result::Result<(), #crate_::DecodeError>
566                where
567                    __B: #crate_::bytes::Buf + ?Sized,
568                {
569                    if <Self as #crate_::encoding::Oneof>::FIELD_TAGS.contains(&tag) {
570                        <Self as #crate_::encoding::OneofDecoder>::oneof_decode_field(
571                            self,
572                            tag,
573                            wire_type,
574                            buf,
575                            ctx,
576                        )
577                    } else {
578                        #crate_::encoding::skip_field(wire_type, buf)
579                    }
580                }
581            }
582        }
583    });
584
585    let impls = quote! {
586        impl #impl_generics #crate_::encoding::RawMessage
587        for #ident #ty_generics #encoder_where_clause {
588            const __ASSERTIONS: () = ();
589
590            #[inline(always)]
591            fn empty() -> Self {
592                <Self as #crate_::encoding::Oneof>::empty()
593            }
594
595            #[inline(always)]
596            fn is_empty(&self) -> bool {
597                <Self as #crate_::encoding::Oneof>::is_empty(self)
598            }
599
600            #[inline(always)]
601            fn clear(&mut self) {
602                <Self as #crate_::encoding::Oneof>::clear(self)
603            }
604
605            #[inline(always)]
606            fn raw_encode<__B>(&self, buf: &mut __B)
607            where
608                __B: #crate_::bytes::BufMut + ?Sized,
609            {
610                <Self as #crate_::encoding::Oneof>::oneof_encode(
611                    self,
612                    buf,
613                    &mut #crate_::encoding::TagWriter::new(),
614                );
615            }
616
617            #[inline(always)]
618            fn raw_prepend<__B>(&self, buf: &mut __B)
619            where
620                __B: #crate_::buf::ReverseBuf + ?Sized,
621            {
622                let tw = &mut #crate_::encoding::TagRevWriter::new();
623                <Self as #crate_::encoding::Oneof>::oneof_prepend(self, buf, tw);
624                tw.finalize(buf);
625            }
626
627            #[inline(always)]
628            fn raw_encoded_len(&self) -> usize {
629                <Self as #crate_::encoding::Oneof>::oneof_encoded_len(
630                    self,
631                    &mut #tag_measurer_ty::new(),
632                )
633            }
634        }
635
636        impl #impl_generics #crate_::encoding::ForOverwrite<(), #ident #ty_generics> for ()
637        #encoder_where_clause {
638            #[inline(always)]
639            fn for_overwrite() -> #ident #ty_generics {
640                <#ident #ty_generics as #crate_::encoding::Oneof>::empty()
641            }
642        }
643
644        impl #impl_generics #crate_::encoding::EmptyState<(), #ident #ty_generics> for ()
645        #encoder_where_clause {
646            #[inline(always)]
647            fn is_empty(val: &#ident #ty_generics) -> bool {
648                <#ident #ty_generics as #crate_::encoding::Oneof>::is_empty(val)
649            }
650
651            #[inline(always)]
652            fn clear(val: &mut #ident #ty_generics) {
653                <#ident #ty_generics as #crate_::encoding::Oneof>::clear(val);
654            }
655        }
656
657        #impl_owned_decoder
658
659        impl #borrow_generics #crate_::encoding::RawMessageBorrowDecoder<'__a>
660        for #ident #ty_generics #borrowed_decoder_where_clause {
661            #[inline(always)]
662            fn raw_borrow_decode_field(
663                &mut self,
664                tag: u32,
665                wire_type: #crate_::encoding::WireType,
666                _duplicated: bool,
667                buf: #crate_::encoding::Capped<&'__a [u8]>,
668                ctx: #crate_::encoding::DecodeContext,
669            ) -> ::core::result::Result<(), #crate_::DecodeError> {
670                if <Self as #crate_::encoding::Oneof>::FIELD_TAGS.contains(&tag) {
671                    <Self as #crate_::encoding::OneofBorrowDecoder>::oneof_borrow_decode_field(
672                        self,
673                        tag,
674                        wire_type,
675                        buf,
676                        ctx,
677                    )
678                } else {
679                    #crate_::encoding::skip_field(wire_type, buf)
680                }
681            }
682        }
683    };
684
685    let distinguished_impls = distinguished.then(|| {
686        let owned_decoder_where_clause = append_wheres(
687            where_clause,
688            [quote!(
689                Self: #crate_::encoding::DistinguishedOneofDecoder + ::core::cmp::Eq
690            )],
691        );
692        let borrowed_decoder_where_clause = append_wheres(
693            where_clause,
694            [quote!(
695                Self: #crate_::encoding::DistinguishedOneofBorrowDecoder<'__a> + ::core::cmp::Eq
696            )],
697        );
698
699        let impl_owned_decoder = (!borrow_only).then(|| {
700            quote! {
701                impl #impl_generics #crate_::encoding::RawDistinguishedMessageDecoder
702                for #ident #ty_generics #owned_decoder_where_clause {
703                    #[inline(always)]
704                    fn raw_decode_field_distinguished<__B>(
705                        &mut self,
706                        tag: u32,
707                        wire_type: #crate_::encoding::WireType,
708                        _duplicated: bool,
709                        buf: #crate_::encoding::Capped<__B>,
710                        ctx: #crate_::encoding::RestrictedDecodeContext,
711                    ) -> ::core::result::Result<#crate_::Canonicity, #crate_::DecodeError>
712                    where
713                        __B: #crate_::bytes::Buf + ?Sized,
714                    {
715                        if <Self as #crate_::encoding::Oneof>::FIELD_TAGS.contains(&tag) {
716                            <Self as #crate_::encoding::DistinguishedOneofDecoder>::
717                                oneof_decode_field_distinguished
718                            (
719                                self,
720                                tag,
721                                wire_type,
722                                buf,
723                                ctx,
724                            )
725                        } else {
726                            _ = ctx.check(#crate_::Canonicity::HasExtensions)?;
727                            #crate_::encoding::skip_field(wire_type, buf)?;
728                            ::core::result::Result::Ok(#crate_::Canonicity::HasExtensions)
729                        }
730                    }
731                }
732            }
733        });
734
735        quote! {
736            #impl_owned_decoder
737
738            impl #borrow_generics #crate_::encoding::RawDistinguishedMessageBorrowDecoder<'__a>
739            for #ident #ty_generics #borrowed_decoder_where_clause {
740                #[inline(always)]
741                fn raw_borrow_decode_field_distinguished(
742                    &mut self,
743                    tag: u32,
744                    wire_type: #crate_::encoding::WireType,
745                    _duplicated: bool,
746                    buf: #crate_::encoding::Capped<&'__a [u8]>,
747                    ctx: #crate_::encoding::RestrictedDecodeContext,
748                ) -> ::core::result::Result<#crate_::Canonicity, #crate_::DecodeError> {
749                    if <Self as #crate_::encoding::Oneof>::FIELD_TAGS.contains(&tag) {
750                        <Self as #crate_::encoding::DistinguishedOneofBorrowDecoder>::
751                            oneof_borrow_decode_field_distinguished
752                        (
753                            self,
754                            tag,
755                            wire_type,
756                            buf,
757                            ctx,
758                        )
759                    } else {
760                        _ = ctx.check(#crate_::Canonicity::HasExtensions)?;
761                        #crate_::encoding::skip_field(wire_type, buf)?;
762                        ::core::result::Result::Ok(#crate_::Canonicity::HasExtensions)
763                    }
764                }
765            }
766        }
767    });
768
769    Ok(quote!(
770        #impls
771
772        #distinguished_impls
773    ))
774}
775
776#[proc_macro_derive(Message, attributes(bilrost))]
777pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
778    try_message(input.into()).unwrap().into()
779}
780
781fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
782    let crate_ = crate_name();
783    let input: DeriveInput = parse2(input)?;
784    let ident = input.ident;
785
786    let generics = &input.generics;
787    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
788    let unborrowed_generics = prepend_to_generics(generics, quote!(const __G: u8));
789    let borrow_generics = prepend_to_generics(generics, quote!('__a, const __G: u8));
790
791    let punctuated_variants = match input.data {
792        Data::Enum(enum_) => enum_.variants,
793        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
794        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
795    };
796
797    struct EnumVariant {
798        variant_ident: Ident,
799        discriminant_expr: Expr,
800    }
801
802    // Parse each variant
803    let mut variants = vec![];
804    for Variant {
805        attrs,
806        ident: variant_ident,
807        fields,
808        discriminant,
809        ..
810    } in punctuated_variants
811    {
812        if !match fields {
813            Fields::Unit => true,
814            Fields::Named(named) => named.named.is_empty(),
815            Fields::Unnamed(unnamed) => unnamed.unnamed.is_empty(),
816        } {
817            bail!("Enumeration variants may not have fields");
818        }
819
820        let discriminant_expr = variant_attr(&attrs)?
821            .or(discriminant.map(|(_, expr)| expr))
822            .ok_or_else(|| {
823                err!(
824                    "Enumeration variants must have a discriminant or a #[bilrost(..)] attribute \
825                    with a constant value"
826                )
827            })?;
828        variants.push(EnumVariant {
829            variant_ident,
830            discriminant_expr,
831        });
832    }
833    let zero_variant_ident = variants
834        .iter()
835        .find(|variant| is_zero_discriminant(&variant.discriminant_expr))
836        .map(|variant| &variant.variant_ident);
837
838    let Some(EnumVariant {
839        variant_ident: first_variant,
840        ..
841    }) = variants.first()
842    else {
843        bail!("Enumerations must have at least one variant");
844    };
845
846    let variant_idents: Vec<_> = variants
847        .iter()
848        .map(|variant| &variant.variant_ident)
849        .collect();
850    let discriminant_exprs: Vec<_> = variants
851        .iter()
852        .map(|variant| &variant.discriminant_expr)
853        .collect();
854
855    // When the type has a zero-valued variant, we implement `EmptyState`. When it doesn't, we
856    // need at least some way to create a value to be overwritten, so we impl `ForOverwrite`
857    // directly with an arbitrary variant.
858    let creation_impl = if let Some(zero) = &zero_variant_ident {
859        quote! {
860            impl #impl_generics #crate_::encoding::ForOverwrite<(), #ident #ty_generics> for ()
861            #where_clause {
862                #[inline]
863                fn for_overwrite() -> #ident #ty_generics {
864                    #ident::#zero { }
865                }
866            }
867
868            impl #impl_generics #crate_::encoding::EmptyState<(), #ident #ty_generics> for ()
869            #where_clause {
870                #[inline]
871                fn is_empty(val: &#ident #ty_generics) -> bool {
872                    matches!(val, #ident::#zero { })
873                }
874
875                #[inline]
876                fn clear(val: &mut #ident #ty_generics) {
877                    *val = #ident::#zero { };
878                }
879            }
880        }
881    } else {
882        quote! {
883            impl #impl_generics #crate_::encoding::ForOverwrite<(), #ident #ty_generics> for ()
884            #where_clause {
885                fn for_overwrite() -> #ident #ty_generics {
886                    #ident::#first_variant { }
887                }
888            }
889        }
890    };
891
892    let expanded = quote! {
893        impl #impl_generics #crate_::Enumeration for #ident #ty_generics #where_clause {
894            #[inline]
895            fn to_number(&self) -> u32 {
896                match self {
897                    #(#ident::#variant_idents { } => #discriminant_exprs,)*
898                }
899            }
900
901            #[inline]
902            fn try_from_number(value: u32) -> ::core::result::Result<#ident, u32> {
903                #[forbid(unreachable_patterns)]
904                ::core::result::Result::Ok(match value {
905                    #(#discriminant_exprs => #ident::#variant_idents { },)*
906                    _ => ::core::result::Result::Err(value)?,
907                })
908            }
909
910            #[inline]
911            fn is_valid(__n: u32) -> bool {
912                #[forbid(unreachable_patterns)]
913                match __n {
914                    #(#discriminant_exprs => true,)*
915                    _ => false,
916                }
917            }
918        }
919
920        #creation_impl
921
922        impl #unborrowed_generics
923        #crate_::encoding::Wiretyped<
924            #crate_::encoding::GeneralGeneric<__G>,
925            #ident #ty_generics
926        > for () #where_clause {
927            const WIRE_TYPE: #crate_::encoding::WireType = #crate_::encoding::WireType::Varint;
928        }
929
930        impl #unborrowed_generics
931        #crate_::encoding::ValueEncoder<
932            #crate_::encoding::GeneralGeneric<__G>,
933            #ident #ty_generics
934        > for () #where_clause {
935            #[inline]
936            fn encode_value<__B: #crate_::bytes::BufMut + ?Sized>(
937                value: &#ident #ty_generics,
938                buf: &mut __B,
939            ) {
940                #crate_::encoding::encode_varint(
941                    #crate_::Enumeration::to_number(value) as u64,
942                    buf,
943                );
944            }
945
946            #[inline]
947            fn prepend_value<__B: #crate_::buf::ReverseBuf + ?Sized>(
948                value: &#ident #ty_generics,
949                buf: &mut __B,
950            ) {
951                #crate_::encoding::prepend_varint(
952                    #crate_::Enumeration::to_number(value) as u64,
953                    buf,
954                );
955            }
956
957            #[inline]
958            fn value_encoded_len(value: &#ident #ty_generics) -> usize {
959                #crate_::encoding::encoded_len_varint(
960                    #crate_::encoding::Enumeration::to_number(value) as u64
961                )
962            }
963        }
964
965        impl #unborrowed_generics
966        #crate_::encoding::ValueDecoder<
967            #crate_::encoding::GeneralGeneric<__G>,
968            #ident #ty_generics
969        > for () #where_clause {
970            #[inline]
971            fn decode_value<__B: #crate_::bytes::Buf + ?Sized>(
972                value: &mut #ident #ty_generics,
973                mut buf: #crate_::encoding::Capped<__B>,
974                _ctx: #crate_::encoding::DecodeContext,
975            ) -> Result<(), #crate_::DecodeError> {
976                let decoded = buf.decode_varint()?;
977                let ::core::result::Result::Ok(in_range) = u32::try_from(decoded) else {
978                    return ::core::result::Result::Err(
979                        #crate_::DecodeErrorKind::OutOfDomainValue.into()
980                    );
981                };
982                let ::core::result::Result::Ok(typed) =
983                    <#ident #ty_generics as #crate_::Enumeration>::try_from_number(in_range) else {
984                    return ::core::result::Result::Err(
985                        #crate_::DecodeErrorKind::OutOfDomainValue.into()
986                    );
987                };
988                *value = typed;
989                ::core::result::Result::Ok(())
990            }
991        }
992
993        impl #unborrowed_generics
994        #crate_::encoding::DistinguishedValueDecoder<
995            #crate_::encoding::GeneralGeneric<__G>,
996            #ident #ty_generics
997        > for () #where_clause {
998            const CHECKS_EMPTY: bool = false;
999
1000            #[inline]
1001            fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
1002                value: &mut #ident #ty_generics,
1003                buf: #crate_::encoding::Capped<impl #crate_::bytes::Buf + ?Sized>,
1004                ctx: #crate_::encoding::RestrictedDecodeContext,
1005            ) -> Result<#crate_::Canonicity, #crate_::DecodeError> {
1006                <() as #crate_::encoding::ValueDecoder<
1007                    #crate_::encoding::GeneralGeneric<__G>, #ident #ty_generics
1008                >>::decode_value(
1009                    value,
1010                    buf,
1011                    ctx.into_inner(),
1012                )?;
1013                ::core::result::Result::Ok(#crate_::Canonicity::Canonical)
1014            }
1015        }
1016
1017        impl #borrow_generics
1018        #crate_::encoding::ValueBorrowDecoder<
1019            '__a,
1020            #crate_::encoding::GeneralGeneric<__G>,
1021            #ident #ty_generics
1022        > for () #where_clause {
1023            #[inline(always)]
1024            fn borrow_decode_value(
1025                value: &mut #ident #ty_generics,
1026                mut buf: #crate_::encoding::Capped<&'__a [u8]>,
1027                ctx: #crate_::encoding::DecodeContext,
1028            ) -> Result<(), #crate_::DecodeError> {
1029                <() as #crate_::encoding::ValueDecoder<
1030                    #crate_::encoding::GeneralGeneric<__G>, #ident #ty_generics
1031                >>::decode_value(
1032                    value,
1033                    buf,
1034                    ctx,
1035                )
1036            }
1037        }
1038
1039        impl #borrow_generics
1040        #crate_::encoding::DistinguishedValueBorrowDecoder<
1041            '__a,
1042            #crate_::encoding::GeneralGeneric<__G>,
1043            #ident #ty_generics
1044        > for () #where_clause {
1045            const CHECKS_EMPTY: bool = false;
1046
1047            #[inline(always)]
1048            fn borrow_decode_value_distinguished<const ALLOW_EMPTY: bool>(
1049                value: &mut #ident #ty_generics,
1050                buf: #crate_::encoding::Capped<&'__a [u8]>,
1051                ctx: #crate_::encoding::RestrictedDecodeContext,
1052            ) -> Result<#crate_::Canonicity, #crate_::DecodeError> {
1053                <() as #crate_::encoding::ValueDecoder<
1054                    #crate_::encoding::GeneralGeneric<__G>, #ident #ty_generics
1055                >>::decode_value(
1056                    value,
1057                    buf,
1058                    ctx.into_inner(),
1059                )?;
1060                ::core::result::Result::Ok(#crate_::Canonicity::Canonical)
1061            }
1062        }
1063    };
1064
1065    Ok(expanded)
1066}
1067
1068#[proc_macro_derive(Enumeration, attributes(bilrost))]
1069pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1070    try_enumeration(input.into()).unwrap().into()
1071}
1072
1073/// Detects whether the given expression, denoting the discriminant of an enumeration variant, is
1074/// definitely zero.
1075fn is_zero_discriminant(expr: &Expr) -> bool {
1076    expr.to_token_stream().to_string() == "0"
1077}
1078
1079/// Get the numeric variant value for an enumeration from attrs.
1080fn variant_attr(attrs: &Vec<Attribute>) -> Result<Option<Expr>, Error> {
1081    let mut result: Option<Expr> = None;
1082    for attr in attrs {
1083        if attr.meta.path().is_ident("bilrost") {
1084            // attribute values for enumerations don't have to be exactly numeric literals, but they
1085            // will need to be used both as a literal-equivalent u32 value and as the match pattern
1086            // for the variant's corresponding value.
1087            let Some(expr) = match &attr.meta {
1088                Meta::List(list) => parse2::<Expr>(list.tokens.clone()).ok(),
1089                Meta::NameValue(name_value) => Some(name_value.value.clone()),
1090                _ => None,
1091            }
1092            .filter(|expr| {
1093                // it's a valid expression; also make sure that it parses successfully as a
1094                // single-variant pattern
1095                syn::parse::Parser::parse2(Pat::parse_single, expr.to_token_stream()).is_ok()
1096            }) else {
1097                bail!(
1098                    "attribute on enumeration variant must be valid as both an expression and a \
1099                    match pattern for u32"
1100                );
1101            };
1102
1103            set_option(
1104                &mut result,
1105                expr,
1106                "duplicate value attributes on enumeration variant",
1107            )?;
1108        }
1109    }
1110    Ok(result)
1111}
1112
1113struct PreprocessedOneof<'a> {
1114    ident: Ident,
1115    impl_generics: &'a Generics,
1116    ty_generics: TypeGenerics<'a>,
1117    where_clause: Option<&'a WhereClause>,
1118    variants: Vec<OneofVariant>,
1119    distinguished: bool,
1120    borrow_only: bool,
1121    empty_variant: Option<Ident>,
1122}
1123
1124fn preprocess_oneof(input: &DeriveInput) -> Result<PreprocessedOneof<'_>, Error> {
1125    let ident = input.ident.clone();
1126
1127    let input_variants = match &input.data {
1128        Data::Enum(enum_) => enum_.variants.clone(),
1129        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
1130        Data::Union(..) => bail!("Oneof can not be derived for a union"),
1131    };
1132
1133    let mut reserved_tags = None;
1134    let mut unknown_attrs = Vec::new();
1135    let mut distinguished = false;
1136    let mut borrow_only = false;
1137    for attr in bilrost_attrs(&input.attrs)? {
1138        if let Some(tags) = tag_list_attr(&attr, "reserved_tags", None)? {
1139            set_option(
1140                &mut reserved_tags,
1141                tags,
1142                "duplicate reserved_tags attributes",
1143            )?
1144        } else if word_attr(&attr, "distinguished") {
1145            set_bool(&mut distinguished, "duplicated distinguished attributes")?;
1146        } else if word_attr(&attr, "borrowed_only") {
1147            set_bool(&mut borrow_only, "duplicated borrowed_only attributes")?;
1148        } else {
1149            unknown_attrs.push(attr);
1150        }
1151    }
1152
1153    if !unknown_attrs.is_empty() {
1154        bail!(
1155            "unknown attribute(s) for oneof-message: {}",
1156            quote!(#(#unknown_attrs),*)
1157        )
1158    }
1159
1160    // Oneof enums have either zero or one unit variant. If there is no such variant, the Oneof
1161    // trait is implemented on `Option<T>`, and `None` stands in for no fields being set. If there
1162    // is such a variant, it becomes the empty state for the type and stands in for no fields being
1163    // set.
1164    let mut empty_variant: Option<Ident> = None;
1165    let mut variants = vec![];
1166    // Map the variants into 'fields'.
1167    for variant in input_variants {
1168        let variant_ident = variant.ident.clone();
1169        match OneofVariant::new(variant)? {
1170            Some(variant) => {
1171                variants.push(variant);
1172            }
1173            None => {
1174                set_option(
1175                    &mut empty_variant,
1176                    variant_ident,
1177                    "Oneofs may have at most one empty enum variant. To use multiple \
1178                    variants without fields, the non-empty variants can be marked as values with \
1179                    the 'message' attribute and the empty variant can be either left un-marked or \
1180                    explicitly marked with the 'empty' attribute.\n\nThe conflicting variants were",
1181                )?;
1182            }
1183        }
1184    }
1185
1186    if distinguished && variants.iter().any(OneofVariant::has_ignored_fields) {
1187        bail!("Oneofs with ignored fields cannot be distinguished");
1188    }
1189
1190    // Index all fields by their tag(s) and check them against the forbidden tag ranges
1191    let all_tags: BTreeMap<u32, &Ident> = variants
1192        .iter()
1193        .map(|variant| (variant.tag(), variant.ident()))
1194        .collect();
1195    for reserved_range in reserved_tags.unwrap_or_default().iter_tag_ranges() {
1196        if let Some((forbidden_tag, variant_ident)) = all_tags.range(reserved_range).next() {
1197            bail!("oneof {ident} variant {variant_ident} has reserved tag {forbidden_tag}");
1198        }
1199    }
1200
1201    let generics = &input.generics;
1202    let (_, ty_generics, where_clause) = generics.split_for_impl();
1203
1204    Ok(PreprocessedOneof {
1205        ident,
1206        impl_generics: generics,
1207        ty_generics,
1208        where_clause,
1209        variants,
1210        distinguished,
1211        borrow_only,
1212        empty_variant,
1213    })
1214}
1215
1216fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
1217    let crate_ = crate_name();
1218    let input: DeriveInput = parse2(input)?;
1219
1220    let PreprocessedOneof {
1221        ident,
1222        impl_generics,
1223        ty_generics,
1224        where_clause,
1225        variants,
1226        distinguished,
1227        borrow_only,
1228        empty_variant,
1229    } = preprocess_oneof(&input)?;
1230
1231    let borrow_generics = prepend_to_generics(impl_generics, quote!('__a));
1232
1233    let encoder_where_clause = append_wheres_with_fields(where_clause, None, &variants, Encode);
1234    let owned_decoder_where_clause =
1235        append_wheres_with_fields(where_clause, None, &variants, Decode(Owned, Relaxed));
1236    let borrowed_decoder_where_clause =
1237        append_wheres_with_fields(where_clause, None, &variants, Decode(Borrowed, Relaxed));
1238
1239    let sorted_tags: Vec<u32> = variants
1240        .iter()
1241        .map(|variant| variant.tag())
1242        .sorted_unstable()
1243        .collect();
1244    if let Some((duplicate_tag, _)) = sorted_tags.iter().tuple_windows().find(|(a, b)| a == b) {
1245        bail!(
1246            "invalid oneof {}: multiple variants have tag {}",
1247            ident,
1248            duplicate_tag
1249        );
1250    }
1251
1252    let self_alias = quote!(Self);
1253
1254    let mut encode: Vec<TokenStream> = variants
1255        .iter()
1256        .map(|variant| variant.encode(&self_alias))
1257        .collect();
1258
1259    let mut prepend: Vec<TokenStream> = variants
1260        .iter()
1261        .map(|variant| variant.prepend(&self_alias))
1262        .collect();
1263
1264    let mut encoded_len: Vec<TokenStream> = variants
1265        .iter()
1266        .map(|variant| variant.encoded_len(&self_alias))
1267        .collect();
1268
1269    let encoder_trait;
1270    let owned_decoder_trait;
1271    let borrowed_decoder_trait;
1272    let decode_field_self_arg;
1273    let decode_field_return_ty;
1274    let current_tag_ty;
1275    let current_tag: Vec<TokenStream>;
1276    let empty_methods_impl;
1277    let some;
1278
1279    if let Some(empty_ident) = &empty_variant {
1280        encoder_trait = quote!(Oneof);
1281        owned_decoder_trait = quote!(OneofDecoder);
1282        borrowed_decoder_trait = quote!(OneofBorrowDecoder<'__a>);
1283        decode_field_self_arg = Some(quote!(value: &mut Self,));
1284        decode_field_return_ty = quote!(());
1285        some = Some(quote!(::core::option::Option::Some));
1286
1287        current_tag_ty = quote!(::core::option::Option<u32>);
1288        current_tag = variants
1289            .iter()
1290            .map(|variant| {
1291                let tag = variant.tag();
1292                let variant_ident = variant.ident();
1293                quote!(Self::#variant_ident { .. } => ::core::option::Option::Some(#tag))
1294            })
1295            .chain([quote!(Self::#empty_ident => ::core::option::Option::None)])
1296            .collect();
1297        encode.push(quote!(Self::#empty_ident => {}));
1298        prepend.push(quote!(Self::#empty_ident => {}));
1299        encoded_len.push(quote!(Self::#empty_ident => 0));
1300
1301        empty_methods_impl = Some(quote! {
1302            fn empty() -> Self {
1303                Self::#empty_ident
1304            }
1305
1306            fn is_empty(&self) -> bool {
1307                matches!(self, Self::#empty_ident)
1308            }
1309
1310            fn clear(&mut self) {
1311                *self = Self::#empty_ident;
1312            }
1313        });
1314    } else {
1315        encoder_trait = quote!(NonEmptyOneof);
1316        owned_decoder_trait = quote!(NonEmptyOneofDecoder);
1317        borrowed_decoder_trait = quote!(NonEmptyOneofBorrowDecoder<'__a>);
1318        decode_field_self_arg = None;
1319        decode_field_return_ty = quote!(Self);
1320        some = None;
1321
1322        // The oneof enum has no "empty" unit variant, so we implement the "non-empty" trait.
1323        current_tag_ty = quote!(u32);
1324        current_tag = variants
1325            .iter()
1326            .map(|variant| {
1327                let tag = variant.tag();
1328                let variant_ident = variant.ident();
1329                quote!(Self::#variant_ident { .. } => #tag)
1330            })
1331            .collect();
1332
1333        empty_methods_impl = None;
1334    };
1335
1336    let variant_name_arms = variants.iter().map(|variant| {
1337        let tag = variant.tag();
1338        let ident_str = ident.to_string();
1339        let variant_ident_str = variant.ident().to_string();
1340        quote! {
1341            #tag => (#ident_str, #variant_ident_str),
1342        }
1343    });
1344
1345    let decode_arms = |lifetime, mode| {
1346        let ident_str = ident.to_string();
1347        let arms = variants
1348            .iter()
1349            .map(|variant| variant.decode(&self_alias, lifetime, mode));
1350        quote! {
1351            match tag {
1352                #(#arms,)*
1353                _ => unreachable!(
1354                    concat!("invalid ", #ident_str, " tag: {}"), tag,
1355                ),
1356            }
1357        }
1358    };
1359
1360    let [decode_owned, decode_borrowed] = match empty_variant {
1361        None => [decode_arms(Owned, Relaxed), decode_arms(Borrowed, Relaxed)],
1362        Some(ref empty_ident) => [
1363            decode_arms(Owned, Relaxed),
1364            decode_arms(Borrowed, Relaxed),
1365        ]
1366            .map(|decode| quote! {
1367            // Guards against colliding oneof field decoding are only evaluated by the Oneof trait,
1368            // when `oneof_decode_field` is called and the oneof value is already populated.
1369            // Whichever implementer is responsible for the oneof having an empty state is also
1370            // responsible for checking for this conflict and returning an error with the
1371            // appropriate decode error kind and augmenting it with the correct variant name, which
1372            // can be gotten from `oneof_variant_name` via the trait.
1373            //
1374            // In this case we are implementing a Oneof that has an intrinsic empty state via a unit
1375            // variant, so we insert this guard right into our trait impl. The signature for this
1376            // method in the `NonEmptyOneof` trait is slightly different to allow for easier nested
1377            // value implementations, and returns the `Self` type directly in the result instead
1378            // with no guard.
1379            //
1380            // The other purpose this guard serves is to attach an error path detail for the field
1381            // in the oneof when it bubbles back up through this call. For that reason, also
1382            // mentioned elsewhere, we structure most of this code to be pretty much one big
1383            // Result-valued expression to serve this match on the very outside.
1384            match if let Self::#empty_ident = value {
1385                match #decode {
1386                    ::core::result::Result::Ok(decoded) => {
1387                        *value = decoded;
1388                        ::core::result::Result::Ok(())
1389                    }
1390                    ::core::result::Result::Err(error) => ::core::result::Result::Err(error),
1391                }
1392            } else {
1393                ::core::result::Result::Err(#crate_::DecodeError::new(
1394                    if #crate_::encoding::#encoder_trait::oneof_current_tag(value) == #some(tag) {
1395                        #crate_::DecodeErrorKind::UnexpectedlyRepeated
1396                    } else {
1397                        #crate_::DecodeErrorKind::ConflictingFields
1398                    }
1399                ))
1400            } {
1401                ::core::result::Result::Err(mut error) => {
1402                    let (msg, field) =
1403                        <Self as #crate_::encoding::#encoder_trait>::oneof_variant_name(tag);
1404                    error.push(msg, field);
1405                    ::core::result::Result::Err(error)
1406                }
1407                ok => ok,
1408            }
1409        })
1410    };
1411
1412    let impl_owned_decoder = (!borrow_only).then(|| {
1413        quote! {
1414            impl #impl_generics #crate_::encoding::#owned_decoder_trait
1415            for __Self #ty_generics #owned_decoder_where_clause
1416            {
1417                fn oneof_decode_field<__B: #crate_::bytes::Buf + ?Sized>(
1418                    #decode_field_self_arg
1419                    tag: u32,
1420                    wire_type: #crate_::encoding::WireType,
1421                    buf: #crate_::encoding::Capped<__B>,
1422                    ctx: #crate_::encoding::DecodeContext,
1423                ) -> ::core::result::Result<#decode_field_return_ty, #crate_::DecodeError> {
1424                    #decode_owned
1425                }
1426            }
1427        }
1428    });
1429
1430    let impls = quote! {
1431        impl #impl_generics #crate_::encoding::#encoder_trait
1432        for __Self #ty_generics #encoder_where_clause
1433        {
1434            const FIELD_TAGS: &'static [u32] = &[#(#sorted_tags),*];
1435
1436            #empty_methods_impl
1437
1438            fn oneof_encode<__B: #crate_::bytes::BufMut + ?Sized>(
1439                &self,
1440                buf: &mut __B,
1441                tw: &mut #crate_::encoding::TagWriter,
1442            ) {
1443                match self {
1444                    #(#encode,)*
1445                }
1446            }
1447
1448            fn oneof_prepend<__B: #crate_::buf::ReverseBuf + ?Sized>(
1449                &self,
1450                buf: &mut __B,
1451                tw: &mut #crate_::encoding::TagRevWriter,
1452            ) {
1453                match self {
1454                    #(#prepend,)*
1455                }
1456            }
1457
1458            fn oneof_encoded_len(
1459                &self,
1460                tm: &mut impl #crate_::encoding::TagMeasurer,
1461            ) -> usize {
1462                match self {
1463                    #(#encoded_len,)*
1464                }
1465            }
1466
1467            fn oneof_current_tag(&self) -> #current_tag_ty {
1468                match self {
1469                    #(#current_tag,)*
1470                }
1471            }
1472
1473            fn oneof_variant_name(tag: u32) -> (&'static str, &'static str) {
1474                match tag {
1475                    #(#variant_name_arms)*
1476                    _ => ("", ""),
1477                }
1478            }
1479        }
1480
1481        #impl_owned_decoder
1482
1483        impl #borrow_generics #crate_::encoding::#borrowed_decoder_trait
1484        for __Self #ty_generics #borrowed_decoder_where_clause
1485        {
1486            fn oneof_borrow_decode_field(
1487                #decode_field_self_arg
1488                tag: u32,
1489                wire_type: #crate_::encoding::WireType,
1490                buf: #crate_::encoding::Capped<&'__a [u8]>,
1491                ctx: #crate_::encoding::DecodeContext,
1492            ) -> ::core::result::Result<#decode_field_return_ty, #crate_::DecodeError> {
1493                #decode_borrowed
1494            }
1495        }
1496    };
1497
1498    let distinguished_impls = distinguished.then(|| {
1499        let owned_decoder_trait;
1500        let borrowed_decoder_trait;
1501        let relaxed_oneof_trait; // we must reference the parent trait for `oneof_current_tag`
1502        let decode_field_self_arg;
1503        let decode_field_return_ty;
1504        let some; // oneofs that have empty states return Option<u32> from `oneof_current_tag`
1505        let owned_decoder_where_clause;
1506        let borrowed_decoder_where_clause;
1507        if empty_variant.is_some() {
1508            owned_decoder_trait = quote!(DistinguishedOneofDecoder);
1509            borrowed_decoder_trait = quote!(DistinguishedOneofBorrowDecoder<'__a>);
1510            relaxed_oneof_trait = quote!(Oneof);
1511            decode_field_self_arg = Some(quote!(value: &mut Self,));
1512            decode_field_return_ty = quote!(#crate_::Canonicity);
1513            some = Some(quote!(::core::option::Option::Some));
1514            [owned_decoder_where_clause, borrowed_decoder_where_clause] =
1515                [Owned, Borrowed].map(|lifetime| {
1516                    append_wheres_with_fields(
1517                        where_clause,
1518                        [quote!(Self: #crate_::encoding::Oneof)],
1519                        &variants,
1520                        Decode(lifetime, Distinguished),
1521                    )
1522                });
1523        } else {
1524            owned_decoder_trait = quote!(NonEmptyDistinguishedOneofDecoder);
1525            borrowed_decoder_trait = quote!(NonEmptyDistinguishedOneofBorrowDecoder<'__a>);
1526            relaxed_oneof_trait = quote!(NonEmptyOneof);
1527            decode_field_self_arg = None;
1528            decode_field_return_ty = quote!((Self, #crate_::Canonicity));
1529            some = None;
1530            [owned_decoder_where_clause, borrowed_decoder_where_clause] =
1531                [Owned, Borrowed].map(|lifetime| {
1532                    append_wheres_with_fields(
1533                        where_clause,
1534                        None,
1535                        &variants,
1536                        Decode(lifetime, Distinguished),
1537                    )
1538                });
1539        };
1540
1541        let [decode_owned, decode_borrowed] = match empty_variant {
1542            None => [
1543                decode_arms(Owned, Distinguished),
1544                decode_arms(Borrowed, Distinguished),
1545            ],
1546            Some(empty_ident) => [
1547                decode_arms(Owned, Distinguished),
1548                decode_arms(Borrowed, Distinguished),
1549            ]
1550            .map(|decode| {
1551                quote! {
1552                    // See the note above for details about the colliding field guard.
1553                    match if let Self::#empty_ident = value {
1554                        match #decode {
1555                            ::core::result::Result::Ok((decoded, canon)) => {
1556                                *value = decoded;
1557                                ::core::result::Result::Ok(canon)
1558                            }
1559                            ::core::result::Result::Err(error) => {
1560                                ::core::result::Result::Err(error)
1561                            },
1562                        }
1563                    } else {
1564                        ::core::result::Result::Err(#crate_::DecodeError::new(
1565                            if #crate_::encoding::#relaxed_oneof_trait::oneof_current_tag(value)
1566                                == #some(tag)
1567                            {
1568                                #crate_::DecodeErrorKind::UnexpectedlyRepeated
1569                            } else {
1570                                #crate_::DecodeErrorKind::ConflictingFields
1571                            }
1572                        ))
1573                    } {
1574                        ::core::result::Result::Err(mut error) => {
1575                            let (msg, field) =
1576                                <Self as #crate_::encoding::#relaxed_oneof_trait>::
1577                                    oneof_variant_name(tag);
1578                            error.push(msg, field);
1579                            ::core::result::Result::Err(error)
1580                        }
1581                        ok => ok,
1582                    }
1583                }
1584            }),
1585        };
1586
1587        let impl_owned_decoder = (!borrow_only).then(|| {
1588            quote! {
1589                impl #impl_generics #crate_::encoding::#owned_decoder_trait
1590                for __Self #ty_generics #owned_decoder_where_clause
1591                {
1592                    fn oneof_decode_field_distinguished<__B: #crate_::bytes::Buf + ?Sized>(
1593                        #decode_field_self_arg
1594                        tag: u32,
1595                        wire_type: #crate_::encoding::WireType,
1596                        buf: #crate_::encoding::Capped<__B>,
1597                        ctx: #crate_::encoding::RestrictedDecodeContext,
1598                    ) -> ::core::result::Result<#decode_field_return_ty, #crate_::DecodeError> {
1599                        #decode_owned
1600                    }
1601                }
1602            }
1603        });
1604
1605        quote! {
1606            #impl_owned_decoder
1607
1608            impl #borrow_generics #crate_::encoding::#borrowed_decoder_trait
1609            for __Self #ty_generics #borrowed_decoder_where_clause
1610            {
1611                fn oneof_borrow_decode_field_distinguished(
1612                    #decode_field_self_arg
1613                    tag: u32,
1614                    wire_type: #crate_::encoding::WireType,
1615                    buf: #crate_::encoding::Capped<&'__a [u8]>,
1616                    ctx: #crate_::encoding::RestrictedDecodeContext,
1617                ) -> ::core::result::Result<#decode_field_return_ty, #crate_::DecodeError> {
1618                    #decode_borrowed
1619                }
1620            }
1621        }
1622    });
1623
1624    let aliases = encoder_alias_header();
1625    Ok(quote! {
1626        const _: () = {
1627            use #ident as __Self;
1628
1629            const _: () = {
1630                #aliases
1631
1632                #impls
1633
1634                #distinguished_impls
1635            };
1636        };
1637    })
1638}
1639
1640#[proc_macro_derive(Oneof, attributes(bilrost))]
1641pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1642    try_oneof(input.into()).unwrap().into()
1643}
1644
/// Unit tests that drive the fallible `try_*` expansion functions directly, so that both accepted
/// inputs and the exact text of rejection errors can be checked without a compiler round-trip.
#[cfg(test)]
mod test {
    use crate::{try_enumeration, try_message, try_oneof};
    use alloc::format;
    use alloc::string::ToString;
    use quote::quote;

    #[test]
    fn test_rejects_colliding_message_fields() {
        let output = try_message(quote! {
            struct Invalid {
                #[bilrost(tag = "1")]
                a: bool,
                #[bilrost(oneof(4, 5, 1))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("duplicate tags not detected").to_string(),
            "multiple fields have tag 1"
        );

        let output = try_message(quote! {
            struct Invalid {
                #[bilrost(tag = "2")]
                a: bool,
                #[bilrost(oneof(1-3))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("duplicate tags not detected").to_string(),
            "multiple fields have tag 2"
        );

        let output = try_message(quote! {
            struct Invalid {
                #[bilrost(tag = "10")]
                a: bool,
                #[bilrost(oneof = "5-10")]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("duplicate tags not detected").to_string(),
            "multiple fields have tag 10"
        );

        // Tags that don't collide with ranges are fine
        _ = try_message(quote! {
            struct Valid {
                #[bilrost(tag = "4")]
                a: bool,
                #[bilrost(oneof(5-10, 1-3))]
                b: Option<super::Whatever>,
            }
        })
        .unwrap();
    }

    #[test]
    fn test_rejects_reserved_message_fields() {
        let output = try_message(quote! {
            #[bilrost(reserved_tags(1, 100))]
            struct Invalid {
                #[bilrost(tag = "1")]
                a: bool,
                #[bilrost(oneof(3-5))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "field a has reserved tag 1"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(4, 55))]
            struct Invalid {
                #[bilrost(tag = "1")]
                a: bool,
                #[bilrost(oneof(3-5))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "field b has reserved tag 4"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(5-10, 55))]
            struct Invalid {
                #[bilrost(tag = "1")]
                a: bool,
                #[bilrost(oneof(3-5))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "field b has reserved tag 5"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(..=3, 55))]
            struct Invalid {
                #[bilrost(tag = "999")]
                a: bool,
                #[bilrost(oneof(3-5))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "field b has reserved tag 3"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(0, 5..))]
            struct Invalid {
                #[bilrost(tag = "1")]
                a: bool,
                #[bilrost(oneof(3-5))]
                b: Option<super::Whatever>,
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "field b has reserved tag 5"
        );
    }

    #[test]
    fn test_rejects_reserved_oneof_fields() {
        let output = try_message(quote! {
            #[bilrost(reserved_tags(1, 100))]
            enum Invalid {
                #[bilrost(tag = "1")]
                A(bool),
                #[bilrost(5)]
                B(super::Whatever),
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "oneof Invalid variant A has reserved tag 1"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(5, 55))]
            enum Invalid {
                #[bilrost(tag = "1")]
                A(bool),
                #[bilrost(5)]
                B(super::Whatever),
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "oneof Invalid variant B has reserved tag 5"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(5-10, 55))]
            enum Invalid {
                #[bilrost(tag = "1")]
                A(bool),
                #[bilrost(5)]
                B(super::Whatever),
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "oneof Invalid variant B has reserved tag 5"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(..=3, 55))]
            enum Invalid {
                #[bilrost(tag = "1")]
                A(bool),
                #[bilrost(5)]
                B(super::Whatever),
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "oneof Invalid variant A has reserved tag 1"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(0, 5..))]
            enum Invalid {
                #[bilrost(tag = "1")]
                A(bool),
                #[bilrost(5)]
                B(super::Whatever),
            }
        });
        assert_eq!(
            output.expect_err("reserved tags not detected").to_string(),
            "oneof Invalid variant B has reserved tag 5"
        );
    }

    #[test]
    fn test_rejects_oversize_oneof_tag_ranges() {
        let output = try_message(quote! {
            struct Invalid {
                #[bilrost(oneof(1-100))]
                a: SomeOneof,
            }
        });
        assert_eq!(
            format!(
                "{:#}",
                output.expect_err("oversized tag range not detected")
            ),
            "invalid field a: too-large tag range 1-100; use smaller ranges"
        );
    }

    #[test]
    fn test_accepts_tag_ranges() {
        try_message(quote! {
            #[bilrost(reserved_tags(1, 2, 3))]
            struct Valid {
                #[bilrost(4)]
                x: String,
            }
        })
        .unwrap();

        try_message(quote! {
            #[bilrost(reserved_tags(1-3, 8-100))]
            struct Valid {
                #[bilrost(4)]
                x: String,
            }
        })
        .unwrap();

        try_message(quote! {
            #[bilrost(reserved_tags(..=3, 8..))]
            struct Valid {
                #[bilrost(4)]
                x: String,
            }
        })
        .unwrap();
    }

    #[test]
    fn test_rejects_colliding_tag_ranges() {
        let output = try_message(quote! {
            #[bilrost(reserved_tags(10, 15, 10))]
            struct Invalid;
        });
        assert_eq!(
            format!(
                "{:#}",
                output.expect_err("colliding reserved tag ranges not detected")
            ),
            "tag 10 is duplicated in tag list"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(1-100, 55))]
            struct Invalid;
        });
        assert_eq!(
            format!(
                "{:#}",
                output.expect_err("colliding reserved tag ranges not detected")
            ),
            "tag 55 is duplicated in tag list"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(1-100, 50-200))]
            struct Invalid;
        });
        assert_eq!(
            format!(
                "{:#}",
                output.expect_err("colliding reserved tag ranges not detected")
            ),
            "tag 50 is duplicated in tag list"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(..=100, 6-10, 2))]
            struct Invalid;
        });
        assert_eq!(
            format!(
                "{:#}",
                output.expect_err("colliding reserved tag ranges not detected")
            ),
            "tag 2 is duplicated in tag list"
        );

        let output = try_message(quote! {
            #[bilrost(reserved_tags(100.., 60, 9999))]
            struct Invalid;
        });
        assert_eq!(
            format!(
                "{:#}",
                output.expect_err("colliding reserved tag ranges not detected")
            ),
            "tag 9999 is duplicated in tag list"
        );
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        let output = try_oneof(quote! {
            pub enum Invalid {
                #[bilrost(tag = "1")]
                A(bool),
                #[bilrost(tag = "1")]
                B(bool),
            }
        });
        assert_eq!(
            output
                .expect_err("conflicting variant tags not detected")
                .to_string(),
            "invalid oneof Invalid: multiple variants have tag 1"
        );
    }

    #[test]
    fn test_basic_message() {
        _ = try_message(quote! {
            pub struct Struct {
                #[bilrost(3)]
                pub fields: BTreeMap<String, i64>,
                #[bilrost(0)]
                pub foo: String,
                #[bilrost(1)]
                pub bar: i64,
                #[bilrost(2)]
                pub baz: bool,
            }
        })
        .unwrap();
    }

    #[test]
    fn test_attribute_forms_are_equivalent() {
        let one = try_message(quote! {
            struct A (
                #[bilrost(tag = "0")] bool,
                #[bilrost(oneof = "2, 3")] B,
                #[bilrost(tag = "4")] u32,
                #[bilrost(tag = "5", encoding = "::custom<Z>")] String,
                #[bilrost(tag = "1000")] i64,
                #[bilrost(tag = "1001")] bool,
            );
        })
        .unwrap()
        .to_string();
        let two = try_message(quote! {
            struct A (
                bool,
                #[bilrost(oneof = "2, 3")] B,
                #[bilrost(4)] u32,
                #[bilrost(encoding(::custom< Z >))] String,
                #[bilrost(tag = 1000)] i64,
                bool,
            );
        })
        .unwrap()
        .to_string();
        let three = try_message(quote! {
            struct A (
                #[bilrost(tag(0))] bool,
                #[bilrost(oneof(2, 3))] B,
                u32,
                #[bilrost(encoding = "::custom <Z>")] String,
                #[bilrost(tag(1000))] i64,
                bool,
            );
        })
        .unwrap()
        .to_string();
        let four = try_message(quote! {
            struct A (
                #[bilrost(0)] bool,
                #[bilrost(oneof(2, 3))] B,
                u32,
                #[bilrost(encoding(::custom<Z>))] String,
                #[bilrost(1000)] i64,
                #[bilrost()] bool,
            );
        })
        .unwrap()
        .to_string();
        let minimal = try_message(quote! {
            struct A (
                bool,
                #[bilrost(oneof(2, 3))] B,
                u32,
                #[bilrost(encoding(::custom<Z>))] String,
                #[bilrost(1000)] i64,
                bool,
            );
        })
        .unwrap()
        .to_string();
        assert_eq!(one, two);
        assert_eq!(one, three);
        assert_eq!(one, four);
        assert_eq!(one, minimal);
    }

    #[test]
    fn test_tuple_message() {
        _ = try_message(quote! {
            struct Tuple(
                #[bilrost(5)] bool,
                #[bilrost(0)] String,
                i64,
            );
        })
        .unwrap();
    }

    #[test]
    fn test_overlapping_message() {
        _ = try_message(quote! {
            struct Struct {
                #[bilrost(0)]
                zero: bool,
                #[bilrost(oneof(1, 10, 20))]
                a: Option<A>,
                #[bilrost(4)]
                four: bool,
                #[bilrost(5)]
                five: bool,
                #[bilrost(oneof(9, 11))]
                b: Option<B>,
                twelve: bool, // implicitly tagged 12
                #[bilrost(oneof(13, 16, 22))]
                c: Option<C>,
                #[bilrost(14)]
                fourteen: bool,
                fifteen: bool, // implicitly tagged 15
                #[bilrost(17)]
                seventeen: bool,
                #[bilrost(oneof(18, 19))]
                d: Option<D>,
                #[bilrost(21)]
                twentyone: bool,
                #[bilrost(50)]
                fifty: bool,
            }
        })
        .unwrap();
    }

    #[test]
    fn test_rejects_conflicting_empty_oneof_variants() {
        let output = try_oneof(quote!(
            enum AB {
                Empty,
                AlsoEmpty,
                #[bilrost(1)]
                A(bool),
                #[bilrost(2)]
                B(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("conflicting empty variants not detected")
                .to_string(),
            "Oneofs may have at most one empty enum variant. To use multiple variants without \
            fields, the non-empty variants can be marked as values with the 'message' attribute \
            and the empty variant can be either left un-marked or explicitly marked with the \
            'empty' attribute.\n\nThe conflicting variants were: Ident(Empty) and Ident(AlsoEmpty)"
        );
    }

    #[test]
    fn test_rejects_meaningless_empty_variant_attrs() {
        let output = try_oneof(quote!(
            enum AB {
                #[bilrost(empty, anything_else)]
                Empty,
                #[bilrost(1)]
                A(bool),
                #[bilrost(2)]
                B(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("unknown attrs on empty variant not detected")
                .to_string(),
            "the 'empty' attribute is combined with other attributes on variant Empty, but it must \
            always be alone"
        );
    }

    #[test]
    fn test_rejects_meaningless_empty_value_variants() {
        let output = try_oneof(quote!(
            enum AB {
                #[bilrost(encoding(X))]
                Empty,
                #[bilrost(1)]
                A(bool),
                #[bilrost(2)]
                B(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("tagless unit variant not detected")
                .to_string(),
            "missing tag attribute on value variant Empty"
        );
        let output = try_oneof(quote!(
            enum AB {
                #[bilrost(tag(0))]
                Empty,
                #[bilrost(1)]
                A(bool),
                #[bilrost(2)]
                B(bool),
            }
        ));
        assert_eq!(
            output.expect_err("unit variant not detected").to_string(),
            "Oneof value variants must have exactly one field, but variant Empty has no fields"
        );
        let output = try_oneof(quote!(
            enum AB {
                #[bilrost(tag(0))]
                Empty {},
                #[bilrost(1)]
                A(bool),
                #[bilrost(2)]
                B(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("brace unit variant not detected")
                .to_string(),
            "Oneof value variants must have exactly one field, but variant Empty has 0 fields"
        );
        let output = try_oneof(quote!(
            enum AB {
                #[bilrost(tag(0))]
                Empty(),
                #[bilrost(1)]
                A(bool),
                #[bilrost(2)]
                B(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("tuple unit variant not detected")
                .to_string(),
            "Oneof value variants must have exactly one field, but variant Empty has 0 fields"
        );
    }

    #[test]
    fn test_rejects_unnumbered_oneof_variants() {
        let output = try_oneof(quote!(
            enum AB {
                #[bilrost(1)]
                A(u32),
                #[bilrost(encoding(packed))]
                B(Vec<String>),
            }
        ));
        assert_eq!(
            output
                .expect_err("unnumbered oneof variant not detected")
                .to_string(),
            "missing tag attribute on value variant B"
        );
    }

    #[test]
    fn test_rejects_struct_and_union_enumerations() {
        let output = try_enumeration(quote!(
            struct X {
                x: String,
            }
        ));
        assert_eq!(
            output
                .expect_err("enumeration of struct not detected")
                .to_string(),
            "Enumeration can not be derived for a struct"
        );
        let output = try_enumeration(quote!(
            union XY {
                x: String,
                Y: Vec<u8>,
            }
        ));
        assert_eq!(
            output
                .expect_err("enumeration of union not detected")
                .to_string(),
            "Enumeration can not be derived for a union"
        );
    }

    #[test]
    fn test_rejects_variant_with_field_in_enumeration() {
        let output = try_enumeration(quote!(
            enum X {
                A = 1,
                B(u32) = 2,
            }
        ));
        assert_eq!(
            output
                .expect_err("variant with field not detected")
                .to_string(),
            "Enumeration variants may not have fields"
        );
    }

    #[test]
    fn test_accepts_mixed_values_in_enumeration() {
        _ = try_enumeration(quote!(
            enum X<T> {
                A = 1,
                #[bilrost = 2]
                B,
                #[bilrost(3)]
                C,
                #[bilrost(SomeType::<T>::SOME_CONSTANT)]
                D,
            }
        ))
        .unwrap();
    }

    #[test]
    fn test_rejects_variant_without_value_in_enumeration() {
        let output = try_enumeration(quote!(
            enum X<T> {
                A = 1,
                #[bilrost = 2]
                B,
                #[bilrost(3)]
                C,
                #[bilrost(SomeType::<T>::SOME_CONSTANT)]
                D,
                HasNoValue,
            }
        ));
        assert_eq!(
            output
                .expect_err("variant without discriminant not detected")
                .to_string(),
            "Enumeration variants must have a discriminant or a #[bilrost(..)] attribute with a \
            constant value"
        );
    }

    #[test]
    fn test_rejects_empty_enumeration() {
        let output = try_enumeration(quote!(
            enum X {}
        ));
        assert_eq!(
            output
                .expect_err("enumeration without variants not detected")
                .to_string(),
            "Enumerations must have at least one variant"
        );
    }

    #[test]
    fn test_accepts_distinguished_and_borrowed_messages() {
        _ = try_message(quote!(
            #[bilrost(distinguished, borrowed_only)]
            struct DistinguishedBorrowedMessage {
                name: &str,
            }
        ))
        .unwrap();
        _ = try_message(quote!(
            #[bilrost(distinguished, borrowed_only)]
            enum DistinguishedBorrowedOneof {
                Empty,
                #[bilrost(1)]
                Name(&str),
            }
        ))
        .unwrap();
    }

    #[test]
    fn test_accepts_distinguished_and_borrowed_oneofs() {
        _ = try_oneof(quote!(
            #[bilrost(distinguished, borrowed_only)]
            enum DistinguishedBorrowedOneof {
                #[bilrost(1)]
                Name(&str),
            }
        ))
        .unwrap();
    }

    #[test]
    fn test_rejects_duplicated_message_attrs() {
        let output = try_message(quote!(
            #[bilrost(distinguished, distinguished)]
            struct DistinguishedBorrowedMessage {
                name: &str,
            }
        ));
        assert_eq!(
            output
                .expect_err("message with duplicated distinguished attributes not detected")
                .to_string(),
            "duplicated distinguished attributes"
        );
        let output = try_message(quote!(
            #[bilrost(borrowed_only, distinguished, borrowed_only)]
            struct DistinguishedBorrowedMessage {
                name: &str,
            }
        ));
        assert_eq!(
            output
                .expect_err("message with duplicated borrowed_only attributes not detected")
                .to_string(),
            "duplicated borrowed_only attributes"
        );

        let output = try_message(quote!(
            #[bilrost(distinguished, distinguished)]
            enum DistinguishedBorrowedOneof {
                Empty,
                #[bilrost(1)]
                Name(&str),
            }
        ));
        assert_eq!(
            output
                .expect_err("message with duplicated distinguished attributes not detected")
                .to_string(),
            "duplicated distinguished attributes"
        );
        let output = try_message(quote!(
            #[bilrost(borrowed_only, distinguished, borrowed_only)]
            enum DistinguishedBorrowedOneof {
                Empty,
                #[bilrost(1)]
                Name(&str),
            }
        ));
        assert_eq!(
            output
                .expect_err("message with duplicated borrowed_only attributes not detected")
                .to_string(),
            "duplicated borrowed_only attributes"
        );
    }

    #[test]
    fn test_rejects_duplicated_oneof_attrs() {
        // This exercises the `Oneof` derive (`try_oneof`). The `Message` derive's handling of
        // the same attributes on enums is already covered by
        // `test_rejects_duplicated_message_attrs`.
        let output = try_oneof(quote!(
            #[bilrost(distinguished, distinguished)]
            enum DistinguishedBorrowedOneof {
                #[bilrost(1)]
                Name(&str),
            }
        ));
        assert_eq!(
            output
                .expect_err("oneof with duplicated distinguished attributes not detected")
                .to_string(),
            "duplicated distinguished attributes"
        );
        let output = try_oneof(quote!(
            #[bilrost(borrowed_only, distinguished, borrowed_only)]
            enum DistinguishedBorrowedOneof {
                #[bilrost(1)]
                Name(&str),
            }
        ));
        assert_eq!(
            output
                .expect_err("oneof with duplicated borrowed_only attributes not detected")
                .to_string(),
            "duplicated borrowed_only attributes"
        );
    }

    #[test]
    fn test_rejects_duplicated_ignore_attrs() {
        let output = try_message(quote!(
            struct VeryIgnored {
                #[bilrost(ignore, ignore)]
                what: u32,
            }
        ));
        assert_eq!(
            output
                .expect_err("field with duplicated ignore attributes not detected")
                .root_cause()
                .to_string(),
            "invalid field what: duplicated ignore attributes for field: ignore , ignore"
        );
        let output = try_message(quote!(
            struct MixedIgnores {
                #[bilrost(tag(123), ignore, ignore)]
                what: u32,
            }
        ));
        assert_eq!(
            output
                .expect_err("field with duplicated ignore attributes not detected")
                .root_cause()
                .to_string(),
            "invalid field what: duplicated ignore attributes for field: tag (123) , ignore , \
            ignore"
        );
        // The same duplicate detection applies to message-level attributes.
        let output = try_message(quote!(
            #[bilrost(default_per_field, default_per_field)]
            struct DuplicatedOnStruct {}
        ));
        assert_eq!(
            output
                .expect_err("message with duplicated default_per_field attributes not detected")
                .to_string(),
            "duplicated default_per_field attributes"
        );
    }
}