prost_reflect_derive/
lib.rs

//! This crate provides the [`ReflectMessage`](https://docs.rs/prost-reflect/latest/prost_reflect/derive.ReflectMessage.html) derive macro.
//!
//! For documentation, see the example in the [`prost-reflect` crate docs](https://docs.rs/prost-reflect/latest/prost_reflect/index.html#deriving-reflectmessage).
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::{quote, ToTokens};
8use syn::spanned::Spanned;
9
10/// A derive macro for the [`ReflectMessage`](https://docs.rs/prost-reflect/latest/prost_reflect/trait.ReflectMessage.html) trait.
11///
12/// For documentation, see the example in the [`prost-reflect` crate docs](https://docs.rs/prost-reflect/latest/prost_reflect/index.html#deriving-reflectmessage).
13#[proc_macro_derive(ReflectMessage, attributes(prost_reflect))]
14pub fn reflect_message(input: TokenStream) -> TokenStream {
15    let input = syn::parse_macro_input!(input as syn::DeriveInput);
16
17    match reflect_message_impl(input) {
18        Ok(tokens) => tokens.into(),
19        Err(err) => err.to_compile_error().into(),
20    }
21}
22
23struct Args {
24    args_span: Span,
25    message_name: Option<syn::Lit>,
26    descriptor_pool: Option<syn::LitStr>,
27    file_descriptor_set: Option<syn::LitStr>,
28}
29
30fn reflect_message_impl(input: syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
31    match &input.data {
32        syn::Data::Struct(_) => (),
33        syn::Data::Enum(_) => return Ok(Default::default()),
34        syn::Data::Union(_) => return Ok(Default::default()),
35    };
36
37    let args = Args::parse(input.ident.span(), &input.attrs)?;
38
39    let name = &input.ident;
40    let descriptor_pool = args.descriptor_pool()?;
41    let message_name = args.message_name()?;
42
43    Ok(quote! {
44        impl ::prost_reflect::ReflectMessage for #name {
45            fn descriptor(&self) -> ::prost_reflect::MessageDescriptor {
46                #descriptor_pool
47                    .get_message_by_name(#message_name)
48                    .expect(concat!("descriptor for message type `", #message_name, "` not found"))
49            }
50        }
51    })
52}
53
54fn is_prost_reflect_attribute(attr: &syn::Attribute) -> bool {
55    attr.path().is_ident("prost_reflect")
56}
57
58impl Args {
59    fn parse(input_span: proc_macro2::Span, attrs: &[syn::Attribute]) -> Result<Args, syn::Error> {
60        let reflect_attrs: Vec<_> = attrs
61            .iter()
62            .filter(|attr| is_prost_reflect_attribute(attr))
63            .collect();
64
65        if reflect_attrs.is_empty() {
66            return Err(syn::Error::new(
67                input_span,
68                "missing #[prost_reflect] attribute",
69            ));
70        }
71
72        let mut args = Args {
73            args_span: reflect_attrs
74                .iter()
75                .map(|a| a.span())
76                .reduce(|l, r| l.join(r).unwrap_or(l))
77                .unwrap(),
78            message_name: None,
79            descriptor_pool: None,
80            file_descriptor_set: None,
81        };
82
83        for attr in reflect_attrs {
84            attr.parse_nested_meta(|nested| {
85                if nested.path.is_ident("descriptor_pool") {
86                    args.descriptor_pool = nested.value()?.parse()?;
87                    Ok(())
88                } else if nested.path.is_ident("file_descriptor_set_bytes") {
89                    args.file_descriptor_set = nested.value()?.parse()?;
90                    Ok(())
91                } else if nested.path.is_ident("message_name") {
92                    args.message_name = nested.value()?.parse()?;
93                    Ok(())
94                } else {
95                    Err(syn::Error::new(
96                        nested.path.span(),
97                        "unknown argument (expected 'descriptor_pool', 'file_descriptor_set_bytes' or 'message_name')",
98                    ))
99                }
100            })?;
101        }
102
103        Ok(args)
104    }
105
106    fn descriptor_pool(&self) -> Result<proc_macro2::TokenStream, syn::Error> {
107        if let Some(descriptor_pool) = &self.descriptor_pool {
108            let expr: syn::Expr = syn::parse_str(&descriptor_pool.value())?;
109            Ok(expr.to_token_stream())
110        } else if let Some(file_descriptor_set) = &self.file_descriptor_set {
111            let expr: syn::Expr = syn::parse_str(&file_descriptor_set.value())?;
112
113            Ok(quote!({
114                static INIT: ::std::sync::Once = ::std::sync::Once::new();
115                INIT.call_once(|| ::prost_reflect::DescriptorPool::decode_global_file_descriptor_set(#expr).unwrap());
116                ::prost_reflect::DescriptorPool::global()
117            }))
118        } else {
119            Err(syn::Error::new(
120                self.args_span,
121                "missing required argument 'descriptor_pool'",
122            ))
123        }
124    }
125
126    fn message_name(&self) -> Result<proc_macro2::TokenStream, syn::Error> {
127        if let Some(message_name) = &self.message_name {
128            Ok(message_name.to_token_stream())
129        } else {
130            Err(syn::Error::new(
131                self.args_span,
132                "missing required argument 'message_name'",
133            ))
134        }
135    }
136}