1#![doc(html_root_url = "https://docs.rs/prost-derive/0.14.2")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Context, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, ExprLit, Fields,
14 FieldsNamed, FieldsUnnamed, Ident, Index, Variant,
15};
16use syn::{Attribute, Lit, Meta, MetaNameValue, Path, Token};
17
18mod field;
19use crate::field::Field;
20
21use self::field::set_option;
22
23fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
24 let input: DeriveInput = syn::parse2(input)?;
25 let ident = input.ident;
26
27 let Attributes {
28 skip_debug,
29 prost_path,
30 } = Attributes::new(input.attrs)?;
31
32 let variant_data = match input.data {
33 Data::Struct(variant_data) => variant_data,
34 Data::Enum(..) => bail!("Message can not be derived for an enum"),
35 Data::Union(..) => bail!("Message can not be derived for a union"),
36 };
37
38 let generics = &input.generics;
39 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40
41 let (is_struct, fields) = match variant_data {
42 DataStruct {
43 fields: Fields::Named(FieldsNamed { named: fields, .. }),
44 ..
45 } => (true, fields.into_iter().collect()),
46 DataStruct {
47 fields:
48 Fields::Unnamed(FieldsUnnamed {
49 unnamed: fields, ..
50 }),
51 ..
52 } => (false, fields.into_iter().collect()),
53 DataStruct {
54 fields: Fields::Unit,
55 ..
56 } => (false, Vec::new()),
57 };
58
59 let mut next_tag: u32 = 1;
60 let mut fields = fields
61 .into_iter()
62 .enumerate()
63 .flat_map(|(i, field)| {
64 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65 let index = Index {
66 index: i as u32,
67 span: Span::call_site(),
68 };
69 quote!(#index)
70 });
71 match Field::new(field.attrs, Some(next_tag)) {
72 Ok(Some(field)) => {
73 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74 Some(Ok((field_ident, field)))
75 }
76 Ok(None) => None,
77 Err(err) => Some(Err(
78 err.context(format!("invalid message field {ident}.{field_ident}"))
79 )),
80 }
81 })
82 .collect::<Result<Vec<_>, _>>()?;
83
84 let unsorted_fields = fields.clone();
86
87 fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
92 let fields = fields;
93
94 if let Some(duplicate_tag) = fields
95 .iter()
96 .flat_map(|(_, field)| field.tags())
97 .duplicates()
98 .next()
99 {
100 bail!("message {ident} has multiple fields with tag {duplicate_tag}",)
101 };
102
103 let encoded_len = fields
104 .iter()
105 .map(|(field_ident, field)| field.encoded_len(&prost_path, quote!(self.#field_ident)));
106
107 let encode = fields
108 .iter()
109 .map(|(field_ident, field)| field.encode(&prost_path, quote!(self.#field_ident)));
110
111 let merge = fields.iter().map(|(field_ident, field)| {
112 let merge = field.merge(&prost_path, quote!(value));
113 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
114 let tags = Itertools::intersperse(tags, quote!(|));
115
116 quote! {
117 #(#tags)* => {
118 let mut value = &mut self.#field_ident;
119 #merge.map_err(|mut error| {
120 error.push(STRUCT_NAME, stringify!(#field_ident));
121 error
122 })
123 },
124 }
125 });
126
127 let struct_name = if fields.is_empty() {
128 quote!()
129 } else {
130 quote!(
131 const STRUCT_NAME: &'static str = stringify!(#ident);
132 )
133 };
134
135 let clear = fields
136 .iter()
137 .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
138
139 let default = if is_struct {
140 let default = fields.iter().map(|(field_ident, field)| {
141 let value = field.default(&prost_path);
142 quote!(#field_ident: #value,)
143 });
144 quote! {#ident {
145 #(#default)*
146 }}
147 } else {
148 let default = fields.iter().map(|(_, field)| {
149 let value = field.default(&prost_path);
150 quote!(#value,)
151 });
152 quote! {#ident (
153 #(#default)*
154 )}
155 };
156
157 let methods = fields
158 .iter()
159 .flat_map(|(field_ident, field)| field.methods(&prost_path, field_ident))
160 .collect::<Vec<_>>();
161 let methods = if methods.is_empty() {
162 quote!()
163 } else {
164 quote! {
165 #[allow(dead_code)]
166 impl #impl_generics #ident #ty_generics #where_clause {
167 #(#methods)*
168 }
169 }
170 };
171
172 let expanded = quote! {
173 impl #impl_generics #prost_path::Message for #ident #ty_generics #where_clause {
174 #[allow(unused_variables)]
175 fn encode_raw(&self, buf: &mut impl #prost_path::bytes::BufMut) {
176 #(#encode)*
177 }
178
179 #[allow(unused_variables)]
180 fn merge_field(
181 &mut self,
182 tag: u32,
183 wire_type: #prost_path::encoding::wire_type::WireType,
184 buf: &mut impl #prost_path::bytes::Buf,
185 ctx: #prost_path::encoding::DecodeContext,
186 ) -> ::core::result::Result<(), #prost_path::DecodeError>
187 {
188 #struct_name
189 match tag {
190 #(#merge)*
191 _ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx),
192 }
193 }
194
195 #[inline]
196 fn encoded_len(&self) -> usize {
197 0 #(+ #encoded_len)*
198 }
199
200 fn clear(&mut self) {
201 #(#clear;)*
202 }
203 }
204
205 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
206 fn default() -> Self {
207 #default
208 }
209 }
210 };
211 let expanded = if skip_debug {
212 expanded
213 } else {
214 let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
215 let wrapper = field.debug(&prost_path, quote!(self.#field_ident));
216 let call = if is_struct {
217 quote!(builder.field(stringify!(#field_ident), &wrapper))
218 } else {
219 quote!(builder.field(&wrapper))
220 };
221 quote! {
222 let builder = {
223 let wrapper = #wrapper;
224 #call
225 };
226 }
227 });
228 let debug_builder = if is_struct {
229 quote!(f.debug_struct(stringify!(#ident)))
230 } else {
231 quote!(f.debug_tuple(stringify!(#ident)))
232 };
233 quote! {
234 #expanded
235
236 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
237 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
238 let mut builder = #debug_builder;
239 #(#debugs;)*
240 builder.finish()
241 }
242 }
243 }
244 };
245
246 let expanded = quote! {
247 #expanded
248
249 #methods
250 };
251
252 Ok(expanded)
253}
254
255#[proc_macro_derive(Message, attributes(prost))]
256pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
257 try_message(input.into()).unwrap().into()
258}
259
260fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
261 let input: DeriveInput = syn::parse2(input)?;
262 let ident = input.ident;
263
264 let Attributes { prost_path, .. } = Attributes::new(input.attrs)?;
265
266 let generics = &input.generics;
267 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268
269 let punctuated_variants = match input.data {
270 Data::Enum(DataEnum { variants, .. }) => variants,
271 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273 };
274
275 let mut variants: Vec<(Ident, Expr, Option<TokenStream>)> = Vec::new();
277 for Variant {
278 attrs,
279 ident,
280 fields,
281 discriminant,
282 ..
283 } in punctuated_variants
284 {
285 match fields {
286 Fields::Unit => (),
287 Fields::Named(_) | Fields::Unnamed(_) => {
288 bail!("Enumeration variants may not have fields")
289 }
290 }
291 match discriminant {
292 Some((_, expr)) => {
293 let deprecated_attr = if attrs.iter().any(|v| v.path().is_ident("deprecated")) {
294 Some(quote!(#[allow(deprecated)]))
295 } else {
296 None
297 };
298 variants.push((ident, expr, deprecated_attr))
299 }
300 None => bail!("Enumeration variants must have a discriminant"),
301 }
302 }
303
304 if variants.is_empty() {
305 panic!("Enumeration must have at least one variant");
306 }
307
308 let (default, _, default_deprecated) = variants[0].clone();
309
310 let is_valid = variants.iter().map(|(_, value, _)| quote!(#value => true));
311 let from = variants
312 .iter()
313 .map(|(variant, value, deprecated)| quote!(#value => ::core::option::Option::Some(#deprecated #ident::#variant)));
314
315 let try_from = variants
316 .iter()
317 .map(|(variant, value, deprecated)| quote!(#value => ::core::result::Result::Ok(#deprecated #ident::#variant)));
318
319 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{ident}`.");
320 let from_i32_doc =
321 format!("Converts an `i32` to a `{ident}`, or `None` if `value` is not a valid variant.");
322
323 let expanded = quote! {
324 impl #impl_generics #ident #ty_generics #where_clause {
325 #[doc=#is_valid_doc]
326 pub fn is_valid(value: i32) -> bool {
327 match value {
328 #(#is_valid,)*
329 _ => false,
330 }
331 }
332
333 #[deprecated = "Use the TryFrom<i32> implementation instead"]
334 #[doc=#from_i32_doc]
335 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
336 match value {
337 #(#from,)*
338 _ => ::core::option::Option::None,
339 }
340 }
341 }
342
343 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
344 fn default() -> #ident {
345 #default_deprecated #ident::#default
346 }
347 }
348
349 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
350 fn from(value: #ident) -> i32 {
351 value as i32
352 }
353 }
354
355 impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
356 type Error = #prost_path::UnknownEnumValue;
357
358 fn try_from(value: i32) -> ::core::result::Result<#ident, #prost_path::UnknownEnumValue> {
359 match value {
360 #(#try_from,)*
361 _ => ::core::result::Result::Err(#prost_path::UnknownEnumValue(value)),
362 }
363 }
364 }
365 };
366
367 Ok(expanded)
368}
369
370#[proc_macro_derive(Enumeration, attributes(prost))]
371pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
372 try_enumeration(input.into()).unwrap().into()
373}
374
/// Implements `#[derive(Oneof)]` for an enum whose variants each wrap a single field.
///
/// Each variant's `prost` attributes are parsed with `Field::new_oneof` and must
/// yield exactly one tag. The generated inherent impl provides:
///
/// * `encode(&self, buf)`, which encodes the active variant;
/// * `merge(field: &mut Option<Self>, tag, wire_type, buf, ctx)`. It merges into the
///   current value when the same variant is already set. Otherwise it merges into a
///   fresh default value and stores that value only if merging succeeds;
/// * `encoded_len(&self)`.
///
/// A `Debug` impl is also generated unless the container carries `skip_debug`.
///
/// # Errors
///
/// Returns an error if the input is a struct or union, if a variant does not have
/// exactly one field, if a variant's attributes are invalid or produce no field, or
/// if two variants share a tag.
fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
    let input: DeriveInput = syn::parse2(input)?;

    let ident = input.ident;

    let Attributes {
        skip_debug,
        prost_path,
    } = Attributes::new(input.attrs)?;

    let variants = match input.data {
        Data::Enum(DataEnum { variants, .. }) => variants,
        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
        Data::Union(..) => bail!("Oneof can not be derived for a union"),
    };

    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    // (variant name, parsed field, `#[allow(deprecated)]` if the variant is deprecated)
    let mut fields: Vec<(Ident, Field, Option<TokenStream>)> = Vec::new();
    for Variant {
        attrs,
        ident: variant_ident,
        fields: variant_fields,
        ..
    } in variants
    {
        // Named and tuple variants are treated alike here; only the field count is checked.
        let variant_fields = match variant_fields {
            Fields::Unit => Punctuated::new(),
            Fields::Named(FieldsNamed { named: fields, .. })
            | Fields::Unnamed(FieldsUnnamed {
                unnamed: fields, ..
            }) => fields,
        };
        if variant_fields.len() != 1 {
            bail!("Oneof enum variants must have a single field");
        }
        // Generated code that names a deprecated variant is prefixed with
        // `#[allow(deprecated)]` to suppress deprecation warnings.
        let deprecated_attr = if attrs.iter().any(|v| v.path().is_ident("deprecated")) {
            Some(quote!(#[allow(deprecated)]))
        } else {
            None
        };
        match Field::new_oneof(attrs)? {
            Some(field) => fields.push((variant_ident, field, deprecated_attr)),
            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
        }
    }

    // Internal invariant: every oneof variant has exactly one tag. Repeated `tag`
    // attributes and `tags = ...` are rejected earlier by `Field::new_oneof`, as the
    // tests at the end of this file check.
    assert!(fields.iter().all(|(_, field, _)| field.tags().len() == 1));

    if let Some(duplicate_tag) = fields
        .iter()
        .flat_map(|(_, field, _)| field.tags())
        .duplicates()
        .next()
    {
        bail!("invalid oneof {ident}: multiple variants have tag {duplicate_tag}");
    }

    // Match arms for `encode`: one per variant, encoding the wrapped value.
    let encode = fields.iter().map(|(variant_ident, field, deprecated)| {
        let encode = field.encode(&prost_path, quote!(*value));
        quote!(#deprecated #ident::#variant_ident(ref value) => { #encode })
    });

    // Match arms for `merge`, keyed by each variant's single tag. If `field` already
    // holds this variant, merge in place. Otherwise merge into a default value and
    // replace `*field` only when the merge returns `Ok`.
    let merge = fields.iter().map(|(variant_ident, field, deprecated)| {
        let tag = field.tags()[0];
        let merge = field.merge(&prost_path, quote!(value));
        quote! {
            #deprecated
            #tag => if let ::core::option::Option::Some(#ident::#variant_ident(value)) = field {
                #merge
            } else {
                let mut owned_value = ::core::default::Default::default();
                let value = &mut owned_value;
                #merge.map(|_| *field = ::core::option::Option::Some(#deprecated #ident::#variant_ident(owned_value)))
            }
        }
    });

    // Match arms for `encoded_len`: one per variant.
    let encoded_len = fields.iter().map(|(variant_ident, field, deprecated)| {
        let encoded_len = field.encoded_len(&prost_path, quote!(*value));
        quote!(#deprecated #ident::#variant_ident(ref value) => #encoded_len)
    });

    // NOTE(review): the generated `merge` hits `unreachable!` for a tag that matches no
    // variant. Presumably the enclosing message only dispatches this oneof's own tags to
    // it; confirm against the field-level code generation.
    let expanded = quote! {
        impl #impl_generics #ident #ty_generics #where_clause {
            pub fn encode(&self, buf: &mut impl #prost_path::bytes::BufMut) {
                match *self {
                    #(#encode,)*
                }
            }

            pub fn merge(
                field: &mut ::core::option::Option<#ident #ty_generics>,
                tag: u32,
                wire_type: #prost_path::encoding::wire_type::WireType,
                buf: &mut impl #prost_path::bytes::Buf,
                ctx: #prost_path::encoding::DecodeContext,
            ) -> ::core::result::Result<(), #prost_path::DecodeError>
            {
                match tag {
                    #(#merge,)*
                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
                }
            }

            #[inline]
            pub fn encoded_len(&self) -> usize {
                match *self {
                    #(#encoded_len,)*
                }
            }
        }

    };
    let expanded = if skip_debug {
        expanded
    } else {
        // `Debug` renders the active variant as a one-element tuple holding the field's
        // debug wrapper.
        let debug = fields.iter().map(|(variant_ident, field, deprecated)| {
            let wrapper = field.debug(&prost_path, quote!(*value));
            quote!(#deprecated #ident::#variant_ident(ref value) => {
                let wrapper = #wrapper;
                f.debug_tuple(stringify!(#variant_ident))
                    .field(&wrapper)
                    .finish()
            })
        });
        quote! {
            #expanded

            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    match *self {
                        #(#debug,)*
                    }
                }
            }
        }
    };

    Ok(expanded)
}
523
524#[proc_macro_derive(Oneof, attributes(prost))]
525pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
526 try_oneof(input.into()).unwrap().into()
527}
528
529fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> {
531 let mut result = Vec::new();
532 for attr in attrs.iter() {
533 if let Meta::List(meta_list) = &attr.meta {
534 if meta_list.path.is_ident("prost") {
535 result.extend(
536 meta_list
537 .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?
538 .into_iter(),
539 )
540 }
541 }
542 }
543 Ok(result)
544}
545
546fn get_prost_path(attrs: &[Meta]) -> Result<Path, Error> {
549 let mut prost_path = None;
550
551 for attr in attrs {
552 match attr {
553 Meta::NameValue(MetaNameValue {
554 path,
555 value:
556 Expr::Lit(ExprLit {
557 lit: Lit::Str(lit), ..
558 }),
559 ..
560 }) if path.is_ident("prost_path") => {
561 let path: Path =
562 syn::parse_str(&lit.value()).context("invalid prost_path argument")?;
563
564 set_option(&mut prost_path, path, "duplicate prost_path attributes")?;
565 }
566 _ => continue,
567 }
568 }
569
570 let prost_path =
571 prost_path.unwrap_or_else(|| syn::parse_str("::prost").expect("default prost_path"));
572
573 Ok(prost_path)
574}
575
576struct Attributes {
577 skip_debug: bool,
578 prost_path: Path,
579}
580
581impl Attributes {
582 fn new(attrs: Vec<Attribute>) -> Result<Self, Error> {
583 syn::custom_keyword!(skip_debug);
584 let skip_debug = attrs.iter().any(|a| a.parse_args::<skip_debug>().is_ok());
585
586 let attrs = prost_attrs(attrs)?;
587 let prost_path = get_prost_path(&attrs)?;
588
589 Ok(Self {
590 skip_debug,
591 prost_path,
592 })
593 }
594}
595
#[cfg(test)]
mod test {
    use crate::{try_message, try_oneof};
    use quote::quote;

    /// Asserts that a derive attempt failed and returns the top-level error message.
    fn rejection<T: std::fmt::Debug, E: std::fmt::Display>(
        result: Result<T, E>,
        context: &str,
    ) -> String {
        result.expect_err(context).to_string()
    }

    #[test]
    fn test_rejects_colliding_message_fields() {
        let invalid = quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        );
        assert_eq!(
            rejection(
                try_message(invalid),
                "did not reject colliding message fields"
            ),
            "message Invalid has multiple fields with tag 1"
        );
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        let invalid = quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        );
        assert_eq!(
            rejection(
                try_oneof(invalid),
                "did not reject colliding oneof variants"
            ),
            "invalid oneof Invalid: multiple variants have tag 1"
        );
    }

    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        let cases = [
            (
                quote!(
                    enum What {
                        #[prost(bool, tag = "1", tag = "2")]
                        A(bool),
                    }
                ),
                "duplicate tag attributes: 1 and 2",
            ),
            (
                quote!(
                    enum What {
                        #[prost(bool, tag = "3")]
                        #[prost(tag = "4")]
                        A(bool),
                    }
                ),
                "duplicate tag attributes: 3 and 4",
            ),
            (
                quote!(
                    enum What {
                        #[prost(bool, tags = "5,6")]
                        A(bool),
                    }
                ),
                "unknown attribute(s): #[prost(tags = \"5,6\")]",
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(
                rejection(
                    try_oneof(input),
                    "did not reject multiple tags on oneof variant"
                ),
                expected
            );
        }
    }
}