prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/prost-derive/0.14.2")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Context, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, ExprLit, Fields,
14    FieldsNamed, FieldsUnnamed, Ident, Index, Variant,
15};
16use syn::{Attribute, Lit, Meta, MetaNameValue, Path, Token};
17
18mod field;
19use crate::field::Field;
20
21use self::field::set_option;
22
23fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
24    let input: DeriveInput = syn::parse2(input)?;
25    let ident = input.ident;
26
27    let Attributes {
28        skip_debug,
29        prost_path,
30    } = Attributes::new(input.attrs)?;
31
32    let variant_data = match input.data {
33        Data::Struct(variant_data) => variant_data,
34        Data::Enum(..) => bail!("Message can not be derived for an enum"),
35        Data::Union(..) => bail!("Message can not be derived for a union"),
36    };
37
38    let generics = &input.generics;
39    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40
41    let (is_struct, fields) = match variant_data {
42        DataStruct {
43            fields: Fields::Named(FieldsNamed { named: fields, .. }),
44            ..
45        } => (true, fields.into_iter().collect()),
46        DataStruct {
47            fields:
48                Fields::Unnamed(FieldsUnnamed {
49                    unnamed: fields, ..
50                }),
51            ..
52        } => (false, fields.into_iter().collect()),
53        DataStruct {
54            fields: Fields::Unit,
55            ..
56        } => (false, Vec::new()),
57    };
58
59    let mut next_tag: u32 = 1;
60    let mut fields = fields
61        .into_iter()
62        .enumerate()
63        .flat_map(|(i, field)| {
64            let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65                let index = Index {
66                    index: i as u32,
67                    span: Span::call_site(),
68                };
69                quote!(#index)
70            });
71            match Field::new(field.attrs, Some(next_tag)) {
72                Ok(Some(field)) => {
73                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74                    Some(Ok((field_ident, field)))
75                }
76                Ok(None) => None,
77                Err(err) => Some(Err(
78                    err.context(format!("invalid message field {ident}.{field_ident}"))
79                )),
80            }
81        })
82        .collect::<Result<Vec<_>, _>>()?;
83
84    // We want Debug to be in declaration order
85    let unsorted_fields = fields.clone();
86
87    // Sort the fields by tag number so that fields will be encoded in tag order.
88    // TODO: This encodes oneof fields in the position of their lowest tag,
89    // regardless of the currently occupied variant, is that consequential?
90    // See: https://protobuf.dev/programming-guides/encoding/#order
91    fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
92    let fields = fields;
93
94    if let Some(duplicate_tag) = fields
95        .iter()
96        .flat_map(|(_, field)| field.tags())
97        .duplicates()
98        .next()
99    {
100        bail!("message {ident} has multiple fields with tag {duplicate_tag}",)
101    };
102
103    let encoded_len = fields
104        .iter()
105        .map(|(field_ident, field)| field.encoded_len(&prost_path, quote!(self.#field_ident)));
106
107    let encode = fields
108        .iter()
109        .map(|(field_ident, field)| field.encode(&prost_path, quote!(self.#field_ident)));
110
111    let merge = fields.iter().map(|(field_ident, field)| {
112        let merge = field.merge(&prost_path, quote!(value));
113        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
114        let tags = Itertools::intersperse(tags, quote!(|));
115
116        quote! {
117            #(#tags)* => {
118                let mut value = &mut self.#field_ident;
119                #merge.map_err(|mut error| {
120                    error.push(STRUCT_NAME, stringify!(#field_ident));
121                    error
122                })
123            },
124        }
125    });
126
127    let struct_name = if fields.is_empty() {
128        quote!()
129    } else {
130        quote!(
131            const STRUCT_NAME: &'static str = stringify!(#ident);
132        )
133    };
134
135    let clear = fields
136        .iter()
137        .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
138
139    let default = if is_struct {
140        let default = fields.iter().map(|(field_ident, field)| {
141            let value = field.default(&prost_path);
142            quote!(#field_ident: #value,)
143        });
144        quote! {#ident {
145            #(#default)*
146        }}
147    } else {
148        let default = fields.iter().map(|(_, field)| {
149            let value = field.default(&prost_path);
150            quote!(#value,)
151        });
152        quote! {#ident (
153            #(#default)*
154        )}
155    };
156
157    let methods = fields
158        .iter()
159        .flat_map(|(field_ident, field)| field.methods(&prost_path, field_ident))
160        .collect::<Vec<_>>();
161    let methods = if methods.is_empty() {
162        quote!()
163    } else {
164        quote! {
165            #[allow(dead_code)]
166            impl #impl_generics #ident #ty_generics #where_clause {
167                #(#methods)*
168            }
169        }
170    };
171
172    let expanded = quote! {
173        impl #impl_generics #prost_path::Message for #ident #ty_generics #where_clause {
174            #[allow(unused_variables)]
175            fn encode_raw(&self, buf: &mut impl #prost_path::bytes::BufMut) {
176                #(#encode)*
177            }
178
179            #[allow(unused_variables)]
180            fn merge_field(
181                &mut self,
182                tag: u32,
183                wire_type: #prost_path::encoding::wire_type::WireType,
184                buf: &mut impl #prost_path::bytes::Buf,
185                ctx: #prost_path::encoding::DecodeContext,
186            ) -> ::core::result::Result<(), #prost_path::DecodeError>
187            {
188                #struct_name
189                match tag {
190                    #(#merge)*
191                    _ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx),
192                }
193            }
194
195            #[inline]
196            fn encoded_len(&self) -> usize {
197                0 #(+ #encoded_len)*
198            }
199
200            fn clear(&mut self) {
201                #(#clear;)*
202            }
203        }
204
205        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
206            fn default() -> Self {
207                #default
208            }
209        }
210    };
211    let expanded = if skip_debug {
212        expanded
213    } else {
214        let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
215            let wrapper = field.debug(&prost_path, quote!(self.#field_ident));
216            let call = if is_struct {
217                quote!(builder.field(stringify!(#field_ident), &wrapper))
218            } else {
219                quote!(builder.field(&wrapper))
220            };
221            quote! {
222                 let builder = {
223                     let wrapper = #wrapper;
224                     #call
225                 };
226            }
227        });
228        let debug_builder = if is_struct {
229            quote!(f.debug_struct(stringify!(#ident)))
230        } else {
231            quote!(f.debug_tuple(stringify!(#ident)))
232        };
233        quote! {
234            #expanded
235
236            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
237                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
238                    let mut builder = #debug_builder;
239                    #(#debugs;)*
240                    builder.finish()
241                }
242            }
243        }
244    };
245
246    let expanded = quote! {
247        #expanded
248
249        #methods
250    };
251
252    Ok(expanded)
253}
254
255#[proc_macro_derive(Message, attributes(prost))]
256pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
257    try_message(input.into()).unwrap().into()
258}
259
260fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
261    let input: DeriveInput = syn::parse2(input)?;
262    let ident = input.ident;
263
264    let Attributes { prost_path, .. } = Attributes::new(input.attrs)?;
265
266    let generics = &input.generics;
267    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268
269    let punctuated_variants = match input.data {
270        Data::Enum(DataEnum { variants, .. }) => variants,
271        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273    };
274
275    // Map the variants into 'fields'.
276    let mut variants: Vec<(Ident, Expr, Option<TokenStream>)> = Vec::new();
277    for Variant {
278        attrs,
279        ident,
280        fields,
281        discriminant,
282        ..
283    } in punctuated_variants
284    {
285        match fields {
286            Fields::Unit => (),
287            Fields::Named(_) | Fields::Unnamed(_) => {
288                bail!("Enumeration variants may not have fields")
289            }
290        }
291        match discriminant {
292            Some((_, expr)) => {
293                let deprecated_attr = if attrs.iter().any(|v| v.path().is_ident("deprecated")) {
294                    Some(quote!(#[allow(deprecated)]))
295                } else {
296                    None
297                };
298                variants.push((ident, expr, deprecated_attr))
299            }
300            None => bail!("Enumeration variants must have a discriminant"),
301        }
302    }
303
304    if variants.is_empty() {
305        panic!("Enumeration must have at least one variant");
306    }
307
308    let (default, _, default_deprecated) = variants[0].clone();
309
310    let is_valid = variants.iter().map(|(_, value, _)| quote!(#value => true));
311    let from = variants
312        .iter()
313        .map(|(variant, value, deprecated)| quote!(#value => ::core::option::Option::Some(#deprecated #ident::#variant)));
314
315    let try_from = variants
316        .iter()
317        .map(|(variant, value, deprecated)| quote!(#value => ::core::result::Result::Ok(#deprecated #ident::#variant)));
318
319    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{ident}`.");
320    let from_i32_doc =
321        format!("Converts an `i32` to a `{ident}`, or `None` if `value` is not a valid variant.");
322
323    let expanded = quote! {
324        impl #impl_generics #ident #ty_generics #where_clause {
325            #[doc=#is_valid_doc]
326            pub fn is_valid(value: i32) -> bool {
327                match value {
328                    #(#is_valid,)*
329                    _ => false,
330                }
331            }
332
333            #[deprecated = "Use the TryFrom<i32> implementation instead"]
334            #[doc=#from_i32_doc]
335            pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
336                match value {
337                    #(#from,)*
338                    _ => ::core::option::Option::None,
339                }
340            }
341        }
342
343        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
344            fn default() -> #ident {
345                #default_deprecated #ident::#default
346            }
347        }
348
349        impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
350            fn from(value: #ident) -> i32 {
351                value as i32
352            }
353        }
354
355        impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
356            type Error = #prost_path::UnknownEnumValue;
357
358            fn try_from(value: i32) -> ::core::result::Result<#ident, #prost_path::UnknownEnumValue> {
359                match value {
360                    #(#try_from,)*
361                    _ => ::core::result::Result::Err(#prost_path::UnknownEnumValue(value)),
362                }
363            }
364        }
365    };
366
367    Ok(expanded)
368}
369
370#[proc_macro_derive(Enumeration, attributes(prost))]
371pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
372    try_enumeration(input.into()).unwrap().into()
373}
374
375fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
376    let input: DeriveInput = syn::parse2(input)?;
377
378    let ident = input.ident;
379
380    let Attributes {
381        skip_debug,
382        prost_path,
383    } = Attributes::new(input.attrs)?;
384
385    let variants = match input.data {
386        Data::Enum(DataEnum { variants, .. }) => variants,
387        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
388        Data::Union(..) => bail!("Oneof can not be derived for a union"),
389    };
390
391    let generics = &input.generics;
392    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
393
394    // Map the variants into 'fields'.
395    let mut fields: Vec<(Ident, Field, Option<TokenStream>)> = Vec::new();
396    for Variant {
397        attrs,
398        ident: variant_ident,
399        fields: variant_fields,
400        ..
401    } in variants
402    {
403        let variant_fields = match variant_fields {
404            Fields::Unit => Punctuated::new(),
405            Fields::Named(FieldsNamed { named: fields, .. })
406            | Fields::Unnamed(FieldsUnnamed {
407                unnamed: fields, ..
408            }) => fields,
409        };
410        if variant_fields.len() != 1 {
411            bail!("Oneof enum variants must have a single field");
412        }
413        let deprecated_attr = if attrs.iter().any(|v| v.path().is_ident("deprecated")) {
414            Some(quote!(#[allow(deprecated)]))
415        } else {
416            None
417        };
418        match Field::new_oneof(attrs)? {
419            Some(field) => fields.push((variant_ident, field, deprecated_attr)),
420            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
421        }
422    }
423
424    // Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple
425    // tags.
426    assert!(fields.iter().all(|(_, field, _)| field.tags().len() == 1));
427
428    if let Some(duplicate_tag) = fields
429        .iter()
430        .flat_map(|(_, field, _)| field.tags())
431        .duplicates()
432        .next()
433    {
434        bail!("invalid oneof {ident}: multiple variants have tag {duplicate_tag}");
435    }
436
437    let encode = fields.iter().map(|(variant_ident, field, deprecated)| {
438        let encode = field.encode(&prost_path, quote!(*value));
439        quote!(#deprecated #ident::#variant_ident(ref value) => { #encode })
440    });
441
442    let merge = fields.iter().map(|(variant_ident, field, deprecated)| {
443        let tag = field.tags()[0];
444        let merge = field.merge(&prost_path, quote!(value));
445        quote! {
446            #deprecated
447            #tag => if let ::core::option::Option::Some(#ident::#variant_ident(value)) = field {
448                #merge
449            } else {
450                let mut owned_value = ::core::default::Default::default();
451                let value = &mut owned_value;
452                #merge.map(|_| *field = ::core::option::Option::Some(#deprecated #ident::#variant_ident(owned_value)))
453            }
454        }
455    });
456
457    let encoded_len = fields.iter().map(|(variant_ident, field, deprecated)| {
458        let encoded_len = field.encoded_len(&prost_path, quote!(*value));
459        quote!(#deprecated #ident::#variant_ident(ref value) => #encoded_len)
460    });
461
462    let expanded = quote! {
463        impl #impl_generics #ident #ty_generics #where_clause {
464            /// Encodes the message to a buffer.
465            pub fn encode(&self, buf: &mut impl #prost_path::bytes::BufMut) {
466                match *self {
467                    #(#encode,)*
468                }
469            }
470
471            /// Decodes an instance of the message from a buffer, and merges it into self.
472            pub fn merge(
473                field: &mut ::core::option::Option<#ident #ty_generics>,
474                tag: u32,
475                wire_type: #prost_path::encoding::wire_type::WireType,
476                buf: &mut impl #prost_path::bytes::Buf,
477                ctx: #prost_path::encoding::DecodeContext,
478            ) -> ::core::result::Result<(), #prost_path::DecodeError>
479            {
480                match tag {
481                    #(#merge,)*
482                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
483                }
484            }
485
486            /// Returns the encoded length of the message without a length delimiter.
487            #[inline]
488            pub fn encoded_len(&self) -> usize {
489                match *self {
490                    #(#encoded_len,)*
491                }
492            }
493        }
494
495    };
496    let expanded = if skip_debug {
497        expanded
498    } else {
499        let debug = fields.iter().map(|(variant_ident, field, deprecated)| {
500            let wrapper = field.debug(&prost_path, quote!(*value));
501            quote!(#deprecated #ident::#variant_ident(ref value) => {
502                let wrapper = #wrapper;
503                f.debug_tuple(stringify!(#variant_ident))
504                    .field(&wrapper)
505                    .finish()
506            })
507        });
508        quote! {
509            #expanded
510
511            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
512                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
513                    match *self {
514                        #(#debug,)*
515                    }
516                }
517            }
518        }
519    };
520
521    Ok(expanded)
522}
523
524#[proc_macro_derive(Oneof, attributes(prost))]
525pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
526    try_oneof(input.into()).unwrap().into()
527}
528
529/// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`.
530fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> {
531    let mut result = Vec::new();
532    for attr in attrs.iter() {
533        if let Meta::List(meta_list) = &attr.meta {
534            if meta_list.path.is_ident("prost") {
535                result.extend(
536                    meta_list
537                        .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?
538                        .into_iter(),
539                )
540            }
541        }
542    }
543    Ok(result)
544}
545
546/// Extracts the path to prost specified using the `#[prost(prost_path = "...")]` attribute. When
547/// missing, falls back to default, which is `::prost`.
548fn get_prost_path(attrs: &[Meta]) -> Result<Path, Error> {
549    let mut prost_path = None;
550
551    for attr in attrs {
552        match attr {
553            Meta::NameValue(MetaNameValue {
554                path,
555                value:
556                    Expr::Lit(ExprLit {
557                        lit: Lit::Str(lit), ..
558                    }),
559                ..
560            }) if path.is_ident("prost_path") => {
561                let path: Path =
562                    syn::parse_str(&lit.value()).context("invalid prost_path argument")?;
563
564                set_option(&mut prost_path, path, "duplicate prost_path attributes")?;
565            }
566            _ => continue,
567        }
568    }
569
570    let prost_path =
571        prost_path.unwrap_or_else(|| syn::parse_str("::prost").expect("default prost_path"));
572
573    Ok(prost_path)
574}
575
576struct Attributes {
577    skip_debug: bool,
578    prost_path: Path,
579}
580
581impl Attributes {
582    fn new(attrs: Vec<Attribute>) -> Result<Self, Error> {
583        syn::custom_keyword!(skip_debug);
584        let skip_debug = attrs.iter().any(|a| a.parse_args::<skip_debug>().is_ok());
585
586        let attrs = prost_attrs(attrs)?;
587        let prost_path = get_prost_path(&attrs)?;
588
589        Ok(Self {
590            skip_debug,
591            prost_path,
592        })
593    }
594}
595
#[cfg(test)]
mod test {
    use crate::{try_message, try_oneof};
    use quote::quote;

    /// Returns the rendered error of a derive expansion, panicking with `context` if the
    /// expansion unexpectedly succeeded.
    fn rendered_error<T, E>(result: Result<T, E>, context: &str) -> String
    where
        T: core::fmt::Debug,
        E: core::fmt::Display,
    {
        result.expect_err(context).to_string()
    }

    #[test]
    fn test_rejects_colliding_message_fields() {
        // Tag 1 is claimed both by `a` and by one of the oneof tags on `b`.
        let input = quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        );
        assert_eq!(
            rendered_error(try_message(input), "did not reject colliding message fields"),
            "message Invalid has multiple fields with tag 1"
        );
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        // Variants `A` and `C` share tag 1.
        let input = quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        );
        assert_eq!(
            rendered_error(try_oneof(input), "did not reject colliding oneof variants"),
            "invalid oneof Invalid: multiple variants have tag 1"
        );
    }

    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        // Each input gives variant `A` more than one tag in a different way.
        let cases = vec![
            (
                quote!(
                    enum What {
                        #[prost(bool, tag = "1", tag = "2")]
                        A(bool),
                    }
                ),
                "duplicate tag attributes: 1 and 2",
            ),
            (
                quote!(
                    enum What {
                        #[prost(bool, tag = "3")]
                        #[prost(tag = "4")]
                        A(bool),
                    }
                ),
                "duplicate tag attributes: 3 and 4",
            ),
            (
                quote!(
                    enum What {
                        #[prost(bool, tags = "5,6")]
                        A(bool),
                    }
                ),
                "unknown attribute(s): #[prost(tags = \"5,6\")]",
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(
                rendered_error(
                    try_oneof(input),
                    "did not reject multiple tags on oneof variant"
                ),
                expected
            );
        }
    }
}