1#![doc(html_root_url = "https://docs.rs/prost-derive/0.14.4")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Context, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, ExprLit, Fields,
14 FieldsNamed, FieldsUnnamed, Ident, Index, Variant,
15};
16use syn::{Attribute, Lit, Meta, MetaNameValue, Path, Token};
17
18mod field;
19use crate::field::Field;
20
21use self::field::set_option;
22
23fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
24 let input: DeriveInput = syn::parse2(input)?;
25 let ident = input.ident;
26
27 let Attributes {
28 skip_debug,
29 prost_path,
30 } = Attributes::new(input.attrs)?;
31
32 let variant_data = match input.data {
33 Data::Struct(variant_data) => variant_data,
34 Data::Enum(..) => bail!("Message can not be derived for an enum"),
35 Data::Union(..) => bail!("Message can not be derived for a union"),
36 };
37
38 let generics = &input.generics;
39 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40
41 let (is_struct, fields) = match variant_data {
42 DataStruct {
43 fields: Fields::Named(FieldsNamed { named: fields, .. }),
44 ..
45 } => (true, fields.into_iter().collect()),
46 DataStruct {
47 fields:
48 Fields::Unnamed(FieldsUnnamed {
49 unnamed: fields, ..
50 }),
51 ..
52 } => (false, fields.into_iter().collect()),
53 DataStruct {
54 fields: Fields::Unit,
55 ..
56 } => (false, Vec::new()),
57 };
58
59 let mut next_tag: u32 = 1;
60 let mut fields = fields
61 .into_iter()
62 .enumerate()
63 .flat_map(|(i, field)| {
64 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65 let index = Index {
66 index: i as u32,
67 span: Span::call_site(),
68 };
69 quote!(#index)
70 });
71 match Field::new(field.attrs, Some(next_tag)) {
72 Ok(Some(field)) => {
73 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74 Some(Ok((field_ident, field)))
75 }
76 Ok(None) => None,
77 Err(err) => Some(Err(
78 err.context(format!("invalid message field {ident}.{field_ident}"))
79 )),
80 }
81 })
82 .collect::<Result<Vec<_>, _>>()?;
83
84 let unsorted_fields = fields.clone();
86
87 fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
92 let fields = fields;
93
94 if let Some(duplicate_tag) = fields
95 .iter()
96 .flat_map(|(_, field)| field.tags())
97 .duplicates()
98 .next()
99 {
100 bail!("message {ident} has multiple fields with tag {duplicate_tag}",)
101 };
102
103 let encoded_len = fields
104 .iter()
105 .map(|(field_ident, field)| field.encoded_len(&prost_path, quote!(self.#field_ident)));
106
107 let encode = fields
108 .iter()
109 .map(|(field_ident, field)| field.encode(&prost_path, quote!(self.#field_ident)));
110
111 let merge = fields.iter().map(|(field_ident, field)| {
112 let merge = field.merge(&prost_path, quote!(value));
113 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
114 let tags = Itertools::intersperse(tags, quote!(|));
115
116 quote! {
117 #(#tags)* => {
118 let mut value = &mut self.#field_ident;
119 #merge.map_err(|mut error| {
120 error.push(STRUCT_NAME, stringify!(#field_ident));
121 error
122 })
123 },
124 }
125 });
126
127 let struct_name = if fields.is_empty() {
128 quote!()
129 } else {
130 quote!(
131 const STRUCT_NAME: &'static str = stringify!(#ident);
132 )
133 };
134
135 let clear = fields
136 .iter()
137 .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
138
139 let default = if is_struct {
140 let default = fields.iter().map(|(field_ident, field)| {
141 let value = field.default(&prost_path);
142 quote!(#field_ident: #value,)
143 });
144 quote! {#ident {
145 #(#default)*
146 }}
147 } else {
148 let default = fields.iter().map(|(_, field)| {
149 let value = field.default(&prost_path);
150 quote!(#value,)
151 });
152 quote! {#ident (
153 #(#default)*
154 )}
155 };
156
157 let methods = fields
158 .iter()
159 .flat_map(|(field_ident, field)| field.methods(&prost_path, field_ident))
160 .collect::<Vec<_>>();
161 let methods = if methods.is_empty() {
162 quote!()
163 } else {
164 quote! {
165 #[allow(dead_code)]
166 impl #impl_generics #ident #ty_generics #where_clause {
167 #(#methods)*
168 }
169 }
170 };
171
172 let expanded = quote! {
173 impl #impl_generics #prost_path::Message for #ident #ty_generics #where_clause {
174 #[allow(unused_variables)]
175 fn encode_raw(&self, buf: &mut impl #prost_path::bytes::BufMut) {
176 #(#encode)*
177 }
178
179 #[allow(unused_variables)]
180 fn merge_field(
181 &mut self,
182 tag: u32,
183 wire_type: #prost_path::encoding::wire_type::WireType,
184 buf: &mut impl #prost_path::bytes::Buf,
185 ctx: #prost_path::encoding::DecodeContext,
186 ) -> ::core::result::Result<(), #prost_path::DecodeError>
187 {
188 #struct_name
189 match tag {
190 #(#merge)*
191 _ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx),
192 }
193 }
194
195 #[inline]
196 fn encoded_len(&self) -> usize {
197 0 #(+ #encoded_len)*
198 }
199
200 fn clear(&mut self) {
201 #(#clear;)*
202 }
203 }
204
205 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
206 fn default() -> Self {
207 #default
208 }
209 }
210 };
211 let expanded = if skip_debug {
212 expanded
213 } else {
214 let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
215 let wrapper = field.debug(&prost_path, quote!(self.#field_ident));
216 let call = if is_struct {
217 quote!(builder.field(stringify!(#field_ident), &wrapper))
218 } else {
219 quote!(builder.field(&wrapper))
220 };
221 quote! {
222 let builder = {
223 let wrapper = #wrapper;
224 #call
225 };
226 }
227 });
228 let debug_builder = if is_struct {
229 quote!(f.debug_struct(stringify!(#ident)))
230 } else {
231 quote!(f.debug_tuple(stringify!(#ident)))
232 };
233 quote! {
234 #expanded
235
236 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
237 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
238 let mut builder = #debug_builder;
239 #(#debugs;)*
240 builder.finish()
241 }
242 }
243 }
244 };
245
246 let expanded = quote! {
247 #expanded
248
249 #methods
250 };
251
252 Ok(expanded)
253}
254
255#[proc_macro_derive(Message, attributes(prost))]
256pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
257 try_message(input.into()).unwrap().into()
258}
259
260fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
261 let input: DeriveInput = syn::parse2(input)?;
262 let ident = input.ident;
263
264 let Attributes { prost_path, .. } = Attributes::new(input.attrs)?;
265
266 let generics = &input.generics;
267 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268
269 let punctuated_variants = match input.data {
270 Data::Enum(DataEnum { variants, .. }) => variants,
271 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273 };
274
275 let mut variants: Vec<(Ident, Expr, Option<TokenStream>)> = Vec::new();
277 for Variant {
278 attrs,
279 ident,
280 fields,
281 discriminant,
282 ..
283 } in punctuated_variants
284 {
285 match fields {
286 Fields::Unit => (),
287 Fields::Named(_) | Fields::Unnamed(_) => {
288 bail!("Enumeration variants may not have fields")
289 }
290 }
291 match discriminant {
292 Some((_, expr)) => {
293 let deprecated_attr = if attrs.iter().any(|v| v.path().is_ident("deprecated")) {
294 Some(quote!(#[allow(deprecated)]))
295 } else {
296 None
297 };
298 variants.push((ident, expr, deprecated_attr))
299 }
300 None => bail!("Enumeration variants must have a discriminant"),
301 }
302 }
303
304 if variants.is_empty() {
305 panic!("Enumeration must have at least one variant");
306 }
307
308 let (default, _, default_deprecated) = variants[0].clone();
309
310 let is_valid = variants.iter().map(|(_, value, _)| quote!(#value => true));
311 let from = variants
312 .iter()
313 .map(|(variant, value, deprecated)| quote!(#value => ::core::option::Option::Some(#deprecated #ident::#variant)));
314
315 let try_from = variants
316 .iter()
317 .map(|(variant, value, deprecated)| quote!(#value => ::core::result::Result::Ok(#deprecated #ident::#variant)));
318
319 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{ident}`.");
320 let from_i32_doc =
321 format!("Converts an `i32` to a `{ident}`, or `None` if `value` is not a valid variant.");
322
323 let expanded = quote! {
324 impl #impl_generics #ident #ty_generics #where_clause {
325 #[doc=#is_valid_doc]
326 pub const fn is_valid(value: i32) -> bool {
327 match value {
328 #(#is_valid,)*
329 _ => false,
330 }
331 }
332
333 #[deprecated = "Use the TryFrom<i32> implementation instead"]
334 #[doc=#from_i32_doc]
335 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
336 match value {
337 #(#from,)*
338 _ => ::core::option::Option::None,
339 }
340 }
341 }
342
343 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
344 fn default() -> #ident {
345 #default_deprecated #ident::#default
346 }
347 }
348
349 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
350 fn from(value: #ident) -> i32 {
351 value as i32
352 }
353 }
354
355 impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
356 type Error = #prost_path::UnknownEnumValue;
357
358 fn try_from(value: i32) -> ::core::result::Result<#ident, #prost_path::UnknownEnumValue> {
359 match value {
360 #(#try_from,)*
361 _ => ::core::result::Result::Err(#prost_path::UnknownEnumValue(value)),
362 }
363 }
364 }
365 };
366
367 Ok(expanded)
368}
369
370#[proc_macro_derive(Enumeration, attributes(prost))]
371pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
372 try_enumeration(input.into()).unwrap().into()
373}
374
375fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
376 let input: DeriveInput = syn::parse2(input)?;
377
378 let ident = input.ident;
379
380 let Attributes {
381 skip_debug,
382 prost_path,
383 } = Attributes::new(input.attrs)?;
384
385 let variants = match input.data {
386 Data::Enum(DataEnum { variants, .. }) => variants,
387 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
388 Data::Union(..) => bail!("Oneof can not be derived for a union"),
389 };
390
391 let generics = &input.generics;
392 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
393
394 let mut fields: Vec<(Ident, Field, Option<TokenStream>)> = Vec::new();
396 for Variant {
397 attrs,
398 ident: variant_ident,
399 fields: variant_fields,
400 ..
401 } in variants
402 {
403 let variant_fields = match variant_fields {
404 Fields::Unit => Punctuated::new(),
405 Fields::Named(FieldsNamed { named: fields, .. })
406 | Fields::Unnamed(FieldsUnnamed {
407 unnamed: fields, ..
408 }) => fields,
409 };
410 if variant_fields.len() != 1 {
411 bail!("Oneof enum variants must have a single field");
412 }
413 let deprecated_attr = if attrs.iter().any(|v| v.path().is_ident("deprecated")) {
414 Some(quote!(#[allow(deprecated)]))
415 } else {
416 None
417 };
418 match Field::new_oneof(attrs)? {
419 Some(field) => fields.push((variant_ident, field, deprecated_attr)),
420 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
421 }
422 }
423
424 assert!(fields.iter().all(|(_, field, _)| field.tags().len() == 1));
427
428 if let Some(duplicate_tag) = fields
429 .iter()
430 .flat_map(|(_, field, _)| field.tags())
431 .duplicates()
432 .next()
433 {
434 bail!("invalid oneof {ident}: multiple variants have tag {duplicate_tag}");
435 }
436
437 let encode = fields.iter().map(|(variant_ident, field, deprecated)| {
438 let encode = field.encode(&prost_path, quote!(*value));
439 quote!(#deprecated #ident::#variant_ident(ref value) => { #encode })
440 });
441
442 let merge = fields.iter().map(|(variant_ident, field, deprecated)| {
443 let tag = field.tags()[0];
444 let merge = field.merge(&prost_path, quote!(value));
445 quote! {
446 #deprecated
447 #tag => if let ::core::option::Option::Some(#ident::#variant_ident(value)) = field {
448 #merge
449 } else {
450 let mut owned_value = ::core::default::Default::default();
451 let value = &mut owned_value;
452 #merge.map(|_| *field = ::core::option::Option::Some(#deprecated #ident::#variant_ident(owned_value)))
453 }
454 }
455 });
456
457 let encoded_len = fields.iter().map(|(variant_ident, field, deprecated)| {
458 let encoded_len = field.encoded_len(&prost_path, quote!(*value));
459 quote!(#deprecated #ident::#variant_ident(ref value) => #encoded_len)
460 });
461
462 let expanded = quote! {
463 impl #impl_generics #ident #ty_generics #where_clause {
464 pub fn encode(&self, buf: &mut impl #prost_path::bytes::BufMut) {
466 match *self {
467 #(#encode,)*
468 }
469 }
470
471 pub fn merge(
473 field: &mut ::core::option::Option<#ident #ty_generics>,
474 tag: u32,
475 wire_type: #prost_path::encoding::wire_type::WireType,
476 buf: &mut impl #prost_path::bytes::Buf,
477 ctx: #prost_path::encoding::DecodeContext,
478 ) -> ::core::result::Result<(), #prost_path::DecodeError>
479 {
480 match tag {
481 #(#merge,)*
482 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
483 }
484 }
485
486 #[inline]
488 pub fn encoded_len(&self) -> usize {
489 match *self {
490 #(#encoded_len,)*
491 }
492 }
493 }
494
495 };
496 let expanded = if skip_debug {
497 expanded
498 } else {
499 let debug = fields.iter().map(|(variant_ident, field, deprecated)| {
500 let wrapper = field.debug(&prost_path, quote!(*value));
501 quote!(#deprecated #ident::#variant_ident(ref value) => {
502 let wrapper = #wrapper;
503 f.debug_tuple(stringify!(#variant_ident))
504 .field(&wrapper)
505 .finish()
506 })
507 });
508 quote! {
509 #expanded
510
511 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
512 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
513 match *self {
514 #(#debug,)*
515 }
516 }
517 }
518 }
519 };
520
521 Ok(expanded)
522}
523
524#[proc_macro_derive(Oneof, attributes(prost))]
525pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
526 try_oneof(input.into()).unwrap().into()
527}
528
529fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> {
531 let mut result = Vec::new();
532 for attr in attrs.iter() {
533 if let Meta::List(meta_list) = &attr.meta {
534 if meta_list.path.is_ident("prost") {
535 result.extend(
536 meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?,
537 )
538 }
539 }
540 }
541 Ok(result)
542}
543
544fn get_prost_path(attrs: &[Meta]) -> Result<Path, Error> {
547 let mut prost_path = None;
548
549 for attr in attrs {
550 match attr {
551 Meta::NameValue(MetaNameValue {
552 path,
553 value:
554 Expr::Lit(ExprLit {
555 lit: Lit::Str(lit), ..
556 }),
557 ..
558 }) if path.is_ident("prost_path") => {
559 let path: Path =
560 syn::parse_str(&lit.value()).context("invalid prost_path argument")?;
561
562 set_option(&mut prost_path, path, "duplicate prost_path attributes")?;
563 }
564 _ => continue,
565 }
566 }
567
568 let prost_path =
569 prost_path.unwrap_or_else(|| syn::parse_str("::prost").expect("default prost_path"));
570
571 Ok(prost_path)
572}
573
574struct Attributes {
575 skip_debug: bool,
576 prost_path: Path,
577}
578
579impl Attributes {
580 fn new(attrs: Vec<Attribute>) -> Result<Self, Error> {
581 syn::custom_keyword!(skip_debug);
582 let skip_debug = attrs.iter().any(|a| a.parse_args::<skip_debug>().is_ok());
583
584 let attrs = prost_attrs(attrs)?;
585 let prost_path = get_prost_path(&attrs)?;
586
587 Ok(Self {
588 skip_debug,
589 prost_path,
590 })
591 }
592}
593
#[cfg(test)]
mod test {
    use crate::{try_message, try_oneof};
    use quote::quote;

    /// A message whose oneof field reuses the tag of a plain field is rejected.
    #[test]
    fn test_rejects_colliding_message_fields() {
        let output = try_message(quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject colliding message fields")
                .to_string(),
            "message Invalid has multiple fields with tag 1"
        );
    }

    /// Two oneof variants sharing a tag are rejected.
    #[test]
    fn test_rejects_colliding_oneof_variants() {
        let output = try_oneof(quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject colliding oneof variants")
                .to_string(),
            "invalid oneof Invalid: multiple variants have tag 1"
        );
    }

    /// A oneof variant may carry only one tag, whether the extra tag is repeated in one
    /// attribute, given in a second attribute, or supplied through `tags`.
    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "1", tag = "2")]
                A(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject multiple tags on oneof variant")
                .to_string(),
            "duplicate tag attributes: 1 and 2"
        );

        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "3")]
                #[prost(tag = "4")]
                A(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject multiple tags on oneof variant")
                .to_string(),
            "duplicate tag attributes: 3 and 4"
        );

        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tags = "5,6")]
                A(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject multiple tags on oneof variant")
                .to_string(),
            "unknown attribute(s): #[prost(tags = \"5,6\")]"
        );
    }
}