prost_reflect_derive/
lib.rs1use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::{quote, ToTokens};
8use syn::spanned::Spanned;
9
10#[proc_macro_derive(ReflectMessage, attributes(prost_reflect))]
14pub fn reflect_message(input: TokenStream) -> TokenStream {
15 let input = syn::parse_macro_input!(input as syn::DeriveInput);
16
17 match reflect_message_impl(input) {
18 Ok(tokens) => tokens.into(),
19 Err(err) => err.to_compile_error().into(),
20 }
21}
22
23struct Args {
24 args_span: Span,
25 message_name: Option<syn::Lit>,
26 descriptor_pool: Option<syn::LitStr>,
27 file_descriptor_set: Option<syn::LitStr>,
28}
29
30fn reflect_message_impl(input: syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
31 match &input.data {
32 syn::Data::Struct(_) => (),
33 syn::Data::Enum(_) => return Ok(Default::default()),
34 syn::Data::Union(_) => return Ok(Default::default()),
35 };
36
37 let args = Args::parse(input.ident.span(), &input.attrs)?;
38
39 let name = &input.ident;
40 let descriptor_pool = args.descriptor_pool()?;
41 let message_name = args.message_name()?;
42
43 Ok(quote! {
44 impl ::prost_reflect::ReflectMessage for #name {
45 fn descriptor(&self) -> ::prost_reflect::MessageDescriptor {
46 #descriptor_pool
47 .get_message_by_name(#message_name)
48 .expect(concat!("descriptor for message type `", #message_name, "` not found"))
49 }
50 }
51 })
52}
53
54fn is_prost_reflect_attribute(attr: &syn::Attribute) -> bool {
55 attr.path().is_ident("prost_reflect")
56}
57
58impl Args {
59 fn parse(input_span: proc_macro2::Span, attrs: &[syn::Attribute]) -> Result<Args, syn::Error> {
60 let reflect_attrs: Vec<_> = attrs
61 .iter()
62 .filter(|attr| is_prost_reflect_attribute(attr))
63 .collect();
64
65 if reflect_attrs.is_empty() {
66 return Err(syn::Error::new(
67 input_span,
68 "missing #[prost_reflect] attribute",
69 ));
70 }
71
72 let mut args = Args {
73 args_span: reflect_attrs
74 .iter()
75 .map(|a| a.span())
76 .reduce(|l, r| l.join(r).unwrap_or(l))
77 .unwrap(),
78 message_name: None,
79 descriptor_pool: None,
80 file_descriptor_set: None,
81 };
82
83 for attr in reflect_attrs {
84 attr.parse_nested_meta(|nested| {
85 if nested.path.is_ident("descriptor_pool") {
86 args.descriptor_pool = nested.value()?.parse()?;
87 Ok(())
88 } else if nested.path.is_ident("file_descriptor_set_bytes") {
89 args.file_descriptor_set = nested.value()?.parse()?;
90 Ok(())
91 } else if nested.path.is_ident("message_name") {
92 args.message_name = nested.value()?.parse()?;
93 Ok(())
94 } else {
95 Err(syn::Error::new(
96 nested.path.span(),
97 "unknown argument (expected 'descriptor_pool', 'file_descriptor_set_bytes' or 'message_name')",
98 ))
99 }
100 })?;
101 }
102
103 Ok(args)
104 }
105
106 fn descriptor_pool(&self) -> Result<proc_macro2::TokenStream, syn::Error> {
107 if let Some(descriptor_pool) = &self.descriptor_pool {
108 let expr: syn::Expr = syn::parse_str(&descriptor_pool.value())?;
109 Ok(expr.to_token_stream())
110 } else if let Some(file_descriptor_set) = &self.file_descriptor_set {
111 let expr: syn::Expr = syn::parse_str(&file_descriptor_set.value())?;
112
113 Ok(quote!({
114 static INIT: ::std::sync::Once = ::std::sync::Once::new();
115 INIT.call_once(|| ::prost_reflect::DescriptorPool::decode_global_file_descriptor_set(#expr).unwrap());
116 ::prost_reflect::DescriptorPool::global()
117 }))
118 } else {
119 Err(syn::Error::new(
120 self.args_span,
121 "missing required argument 'descriptor_pool'",
122 ))
123 }
124 }
125
126 fn message_name(&self) -> Result<proc_macro2::TokenStream, syn::Error> {
127 if let Some(message_name) = &self.message_name {
128 Ok(message_name.to_token_stream())
129 } else {
130 Err(syn::Error::new(
131 self.args_span,
132 "missing required argument 'message_name'",
133 ))
134 }
135 }
136}