ntex_prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/ntex-prost-derive/0.10.3")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15    FieldsUnnamed, Ident, Variant,
16};
17
18mod field;
19mod server;
20
21use crate::field::Field;
22
23#[proc_macro_derive(Message, attributes(prost))]
24pub fn message(input: TokenStream) -> TokenStream {
25    try_message(input).unwrap()
26}
27
28#[proc_macro_derive(Enumeration, attributes(prost))]
29pub fn enumeration(input: TokenStream) -> TokenStream {
30    try_enumeration(input).unwrap()
31}
32
33#[proc_macro_derive(Oneof, attributes(prost))]
34pub fn oneof(input: TokenStream) -> TokenStream {
35    try_oneof(input).unwrap()
36}
37
38#[proc_macro_attribute]
39pub fn server(attr: TokenStream, item: TokenStream) -> TokenStream {
40    server::server(attr, item)
41}
42
43fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
44    let input: DeriveInput = syn::parse(input)?;
45
46    let ident = input.ident;
47
48    let variant_data = match input.data {
49        Data::Struct(variant_data) => variant_data,
50        Data::Enum(..) => bail!("Message can not be derived for an enum"),
51        Data::Union(..) => bail!("Message can not be derived for a union"),
52    };
53
54    let generics = &input.generics;
55    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
56
57    let fields = match variant_data {
58        DataStruct {
59            fields: Fields::Named(FieldsNamed { named: fields, .. }),
60            ..
61        }
62        | DataStruct {
63            fields:
64                Fields::Unnamed(FieldsUnnamed {
65                    unnamed: fields, ..
66                }),
67            ..
68        } => fields.into_iter().collect(),
69        DataStruct {
70            fields: Fields::Unit,
71            ..
72        } => Vec::new(),
73    };
74
75    let mut next_tag: u32 = 1;
76    let mut fields = fields
77        .into_iter()
78        .enumerate()
79        .flat_map(|(idx, field)| {
80            let field_ident = field
81                .ident
82                .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
83            match Field::new(field.attrs, Some(next_tag)) {
84                Ok(Some(field)) => {
85                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
86                    Some(Ok((field_ident, field)))
87                }
88                Ok(None) => None,
89                Err(err) => Some(Err(
90                    err.context(format!("invalid message field {}.{}", ident, field_ident))
91                )),
92            }
93        })
94        .collect::<Result<Vec<_>, _>>()?;
95
96    // We want Debug to be in declaration order
97    let unsorted_fields = fields.clone();
98
99    // Sort the fields by tag number so that fields will be encoded in tag order.
100    // TODO: This encodes oneof fields in the position of their lowest tag,
101    // regardless of the currently occupied variant, is that consequential?
102    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
103    fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
104    let fields = fields;
105
106    let mut tags = fields
107        .iter()
108        .flat_map(|&(_, ref field)| field.tags())
109        .collect::<Vec<_>>();
110    let num_tags = tags.len();
111    tags.sort_unstable();
112    tags.dedup();
113    if tags.len() != num_tags {
114        bail!("message {} has fields with duplicate tags", ident);
115    }
116
117    let encoded_len = fields
118        .iter()
119        .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
120
121    let encode = fields
122        .iter()
123        .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
124
125    let merge = fields.iter().map(|&(ref field_ident, ref field)| {
126        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
127        let tags = Itertools::intersperse(tags, quote!(|));
128
129        if field.is_oneof() {
130            quote! {
131                #(#tags)* => OneofType::merge(&mut msg.#field_ident, tag, wire_type, buf)
132                .map_err(|err| err.push(STRUCT_NAME, stringify!(#field_ident)))?,
133            }
134        } else {
135            quote! {
136                #(#tags)* => NativeType::deserialize(&mut msg.#field_ident, wire_type, buf)
137                .map_err(|err| err.push(STRUCT_NAME, stringify!(#field_ident)))?,
138            }
139        }
140    });
141
142    let struct_name = if fields.is_empty() {
143        quote!()
144    } else {
145        quote!(
146            const STRUCT_NAME: &'static str = stringify!(#ident);
147        )
148    };
149
150    let default = fields.iter().map(|&(ref field_ident, ref field)| {
151        let value = field.default();
152        quote!(#field_ident: #value,)
153    });
154
155    let debugs = unsorted_fields.iter().map(|&(ref field_ident, _)| {
156        quote!(builder.field(stringify!(#field_ident), &self.#field_ident))
157    });
158    let debug_builder = quote!(f.debug_struct(stringify!(#ident)));
159
160    let expanded = quote! {
161        #[allow(unused_variables)]
162        impl ::ntex_grpc::Message for #ident #ty_generics #where_clause {
163            fn write(&self, buf: &mut ::ntex_grpc::types::BytesMut) {
164                use ::ntex_grpc::{NativeType, types::OneofType};
165
166                #(#encode)*
167            }
168
169            fn read(buf: &mut ::ntex_grpc::types::Bytes) -> ::std::result::Result<Self, ::ntex_grpc::DecodeError> {
170                use ::ntex_grpc::{NativeType, types::OneofType};
171
172                #struct_name
173
174                let mut msg = Self::default();
175
176                while !buf.is_empty() {
177                    let (tag, wire_type) = ::ntex_grpc::encoding::decode_key(buf)?;
178
179                    match tag {
180                        #(#merge)*
181                        _ => ::ntex_grpc::encoding::skip_field(wire_type, tag, buf)?,
182                    }
183                }
184
185                Ok(msg)
186            }
187
188            #[inline]
189            fn encoded_len(&self) -> usize {
190                use ::ntex_grpc::{NativeType, types::OneofType};
191
192                0 #(+ #encoded_len)*
193            }
194        }
195
196        impl #impl_generics ::std::default::Default for #ident #ty_generics #where_clause {
197            fn default() -> Self {
198                #ident {
199                    #(#default)*
200                }
201            }
202        }
203
204        impl #impl_generics ::std::fmt::Debug for #ident #ty_generics #where_clause {
205            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
206                let mut builder = #debug_builder;
207                #(#debugs;)*
208                builder.finish()
209            }
210        }
211    };
212
213    Ok(expanded.into())
214}
215
216fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
217    let input: DeriveInput = syn::parse(input)?;
218    let ident = input.ident;
219
220    let generics = &input.generics;
221    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
222
223    let punctuated_variants = match input.data {
224        Data::Enum(DataEnum { variants, .. }) => variants,
225        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
226        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
227    };
228
229    // Map the variants into 'fields'.
230    let mut variants: Vec<(Ident, Expr)> = Vec::new();
231    for Variant {
232        ident,
233        fields,
234        discriminant,
235        ..
236    } in punctuated_variants
237    {
238        match fields {
239            Fields::Unit => (),
240            Fields::Named(_) | Fields::Unnamed(_) => {
241                bail!("Enumeration variants may not have fields")
242            }
243        }
244
245        match discriminant {
246            Some((_, expr)) => variants.push((ident, expr)),
247            None => bail!("Enumeration variants must have a disriminant"),
248        }
249    }
250
251    if variants.is_empty() {
252        panic!("Enumeration must have at least one variant");
253    }
254
255    let default = variants[0].0.clone();
256
257    let is_valid = variants
258        .iter()
259        .map(|&(_, ref value)| quote!(#value => true));
260    let from = variants.iter().map(
261        |&(ref variant, ref value)| quote!(#value => ::std::option::Option::Some(#ident::#variant)),
262    );
263
264    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
265    let from_i32_doc = format!(
266        "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
267        ident
268    );
269
270    let expanded = quote! {
271        impl #impl_generics #ident #ty_generics #where_clause {
272            #[doc=#is_valid_doc]
273            pub fn is_valid(value: i32) -> bool {
274                match value {
275                    #(#is_valid,)*
276                    _ => false,
277                }
278            }
279
280            #[doc=#from_i32_doc]
281            pub fn from_i32(value: i32) -> ::std::option::Option<#ident> {
282                match value {
283                    #(#from,)*
284                    _ => ::std::option::Option::None,
285                }
286            }
287        }
288
289        impl #impl_generics ::std::default::Default for #ident #ty_generics #where_clause {
290            fn default() -> #ident {
291                #ident::#default
292            }
293        }
294
295        impl #impl_generics ::std::convert::From::<#ident> for i32 #ty_generics #where_clause {
296            fn from(value: #ident) -> i32 {
297                value as i32
298            }
299        }
300    };
301
302    Ok(expanded.into())
303}
304
305fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
306    let input: DeriveInput = syn::parse(input)?;
307
308    let ident = input.ident;
309
310    let variants = match input.data {
311        Data::Enum(DataEnum { variants, .. }) => variants,
312        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
313        Data::Union(..) => bail!("Oneof can not be derived for a union"),
314    };
315
316    let generics = &input.generics;
317    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
318
319    // Map the variants into 'fields'.
320    let mut fields: Vec<(Ident, Field)> = Vec::new();
321    for Variant {
322        attrs,
323        ident: variant_ident,
324        fields: variant_fields,
325        ..
326    } in variants
327    {
328        let variant_fields = match variant_fields {
329            Fields::Unit => Punctuated::new(),
330            Fields::Named(FieldsNamed { named: fields, .. })
331            | Fields::Unnamed(FieldsUnnamed {
332                unnamed: fields, ..
333            }) => fields,
334        };
335        if variant_fields.len() != 1 {
336            bail!("Oneof enum variants must have a single field");
337        }
338        match Field::new_oneof(attrs)? {
339            Some(field) => fields.push((variant_ident, field)),
340            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
341        }
342    }
343
344    let mut tags = fields
345        .iter()
346        .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
347            if field.tags().len() > 1 {
348                bail!(
349                    "invalid oneof variant {}::{}: oneof variants may only have a single tag",
350                    ident,
351                    variant_ident
352                );
353            }
354            Ok(field.tags()[0])
355        })
356        .collect::<Vec<_>>();
357    tags.sort_unstable();
358    tags.dedup();
359    if tags.len() != fields.len() {
360        panic!("invalid oneof {}: variants have duplicate tags", ident);
361    }
362
363    let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
364        let encode = field.encode(quote!(*value));
365        quote!(#ident::#variant_ident(ref value) => { #encode })
366    });
367
368    let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
369        let tag = field.tags()[0];
370        quote! {
371            #tag => {
372                #ident::#variant_ident(NativeType::deserialize_default(wire_type, buf)?)
373            }
374        }
375    });
376
377    let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
378        let encoded_len = field.encoded_len(quote!(*value));
379        quote!(#ident::#variant_ident(ref value) => #encoded_len)
380    });
381
382    let debug = fields.iter().map(|&(ref variant_ident, _)| {
383        quote!(#ident::#variant_ident(ref value) => {
384            f.debug_tuple(stringify!(#variant_ident))
385                .field(value)
386                .finish()
387        })
388    });
389
390    let expanded = quote! {
391        impl ::ntex_grpc::types::OneofType for #impl_generics #ident #ty_generics #where_clause {
392            #[inline]
393            /// Encodes the message to a buffer.
394            fn encode(&self, buf: &mut ::ntex_grpc::types::BytesMut) {
395                use ::ntex_grpc::NativeType;
396
397                match *self {
398                    #(#encode,)*
399                }
400            }
401
402            #[inline]
403            /// Decodes an instance of the message from a buffer, and merges it into self.
404            fn decode(tag: u32, wire_type: ::ntex_grpc::types::WireType, buf: &mut ::ntex_grpc::types::Bytes) -> ::std::result::Result<Self, ::ntex_grpc::DecodeError> {
405                use ::ntex_grpc::NativeType;
406
407                Ok(match tag {
408                    #(#merge,)*
409                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
410                })
411            }
412
413            /// Returns the encoded length of the message without a length delimiter.
414            #[inline]
415            fn encoded_len(&self) -> usize {
416                use ::ntex_grpc::NativeType;
417
418                match *self {
419                    #(#encoded_len,)*
420                }
421            }
422        }
423
424        impl #impl_generics ::std::fmt::Debug for #ident #ty_generics #where_clause {
425            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
426                match *self {
427                    #(#debug,)*
428                }
429            }
430        }
431    };
432
433    Ok(expanded.into())
434}