1#![doc(html_root_url = "https://docs.rs/ntex-prost-derive/0.10.3")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15 FieldsUnnamed, Ident, Variant,
16};
17
18mod field;
19mod server;
20
21use crate::field::Field;
22
23#[proc_macro_derive(Message, attributes(prost))]
24pub fn message(input: TokenStream) -> TokenStream {
25 try_message(input).unwrap()
26}
27
28#[proc_macro_derive(Enumeration, attributes(prost))]
29pub fn enumeration(input: TokenStream) -> TokenStream {
30 try_enumeration(input).unwrap()
31}
32
33#[proc_macro_derive(Oneof, attributes(prost))]
34pub fn oneof(input: TokenStream) -> TokenStream {
35 try_oneof(input).unwrap()
36}
37
38#[proc_macro_attribute]
39pub fn server(attr: TokenStream, item: TokenStream) -> TokenStream {
40 server::server(attr, item)
41}
42
43fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
44 let input: DeriveInput = syn::parse(input)?;
45
46 let ident = input.ident;
47
48 let variant_data = match input.data {
49 Data::Struct(variant_data) => variant_data,
50 Data::Enum(..) => bail!("Message can not be derived for an enum"),
51 Data::Union(..) => bail!("Message can not be derived for a union"),
52 };
53
54 let generics = &input.generics;
55 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
56
57 let fields = match variant_data {
58 DataStruct {
59 fields: Fields::Named(FieldsNamed { named: fields, .. }),
60 ..
61 }
62 | DataStruct {
63 fields:
64 Fields::Unnamed(FieldsUnnamed {
65 unnamed: fields, ..
66 }),
67 ..
68 } => fields.into_iter().collect(),
69 DataStruct {
70 fields: Fields::Unit,
71 ..
72 } => Vec::new(),
73 };
74
75 let mut next_tag: u32 = 1;
76 let mut fields = fields
77 .into_iter()
78 .enumerate()
79 .flat_map(|(idx, field)| {
80 let field_ident = field
81 .ident
82 .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
83 match Field::new(field.attrs, Some(next_tag)) {
84 Ok(Some(field)) => {
85 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
86 Some(Ok((field_ident, field)))
87 }
88 Ok(None) => None,
89 Err(err) => Some(Err(
90 err.context(format!("invalid message field {}.{}", ident, field_ident))
91 )),
92 }
93 })
94 .collect::<Result<Vec<_>, _>>()?;
95
96 let unsorted_fields = fields.clone();
98
99 fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
104 let fields = fields;
105
106 let mut tags = fields
107 .iter()
108 .flat_map(|&(_, ref field)| field.tags())
109 .collect::<Vec<_>>();
110 let num_tags = tags.len();
111 tags.sort_unstable();
112 tags.dedup();
113 if tags.len() != num_tags {
114 bail!("message {} has fields with duplicate tags", ident);
115 }
116
117 let encoded_len = fields
118 .iter()
119 .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
120
121 let encode = fields
122 .iter()
123 .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
124
125 let merge = fields.iter().map(|&(ref field_ident, ref field)| {
126 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
127 let tags = Itertools::intersperse(tags, quote!(|));
128
129 if field.is_oneof() {
130 quote! {
131 #(#tags)* => OneofType::merge(&mut msg.#field_ident, tag, wire_type, buf)
132 .map_err(|err| err.push(STRUCT_NAME, stringify!(#field_ident)))?,
133 }
134 } else {
135 quote! {
136 #(#tags)* => NativeType::deserialize(&mut msg.#field_ident, wire_type, buf)
137 .map_err(|err| err.push(STRUCT_NAME, stringify!(#field_ident)))?,
138 }
139 }
140 });
141
142 let struct_name = if fields.is_empty() {
143 quote!()
144 } else {
145 quote!(
146 const STRUCT_NAME: &'static str = stringify!(#ident);
147 )
148 };
149
150 let default = fields.iter().map(|&(ref field_ident, ref field)| {
151 let value = field.default();
152 quote!(#field_ident: #value,)
153 });
154
155 let debugs = unsorted_fields.iter().map(|&(ref field_ident, _)| {
156 quote!(builder.field(stringify!(#field_ident), &self.#field_ident))
157 });
158 let debug_builder = quote!(f.debug_struct(stringify!(#ident)));
159
160 let expanded = quote! {
161 #[allow(unused_variables)]
162 impl ::ntex_grpc::Message for #ident #ty_generics #where_clause {
163 fn write(&self, buf: &mut ::ntex_grpc::types::BytesMut) {
164 use ::ntex_grpc::{NativeType, types::OneofType};
165
166 #(#encode)*
167 }
168
169 fn read(buf: &mut ::ntex_grpc::types::Bytes) -> ::std::result::Result<Self, ::ntex_grpc::DecodeError> {
170 use ::ntex_grpc::{NativeType, types::OneofType};
171
172 #struct_name
173
174 let mut msg = Self::default();
175
176 while !buf.is_empty() {
177 let (tag, wire_type) = ::ntex_grpc::encoding::decode_key(buf)?;
178
179 match tag {
180 #(#merge)*
181 _ => ::ntex_grpc::encoding::skip_field(wire_type, tag, buf)?,
182 }
183 }
184
185 Ok(msg)
186 }
187
188 #[inline]
189 fn encoded_len(&self) -> usize {
190 use ::ntex_grpc::{NativeType, types::OneofType};
191
192 0 #(+ #encoded_len)*
193 }
194 }
195
196 impl #impl_generics ::std::default::Default for #ident #ty_generics #where_clause {
197 fn default() -> Self {
198 #ident {
199 #(#default)*
200 }
201 }
202 }
203
204 impl #impl_generics ::std::fmt::Debug for #ident #ty_generics #where_clause {
205 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
206 let mut builder = #debug_builder;
207 #(#debugs;)*
208 builder.finish()
209 }
210 }
211 };
212
213 Ok(expanded.into())
214}
215
216fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
217 let input: DeriveInput = syn::parse(input)?;
218 let ident = input.ident;
219
220 let generics = &input.generics;
221 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
222
223 let punctuated_variants = match input.data {
224 Data::Enum(DataEnum { variants, .. }) => variants,
225 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
226 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
227 };
228
229 let mut variants: Vec<(Ident, Expr)> = Vec::new();
231 for Variant {
232 ident,
233 fields,
234 discriminant,
235 ..
236 } in punctuated_variants
237 {
238 match fields {
239 Fields::Unit => (),
240 Fields::Named(_) | Fields::Unnamed(_) => {
241 bail!("Enumeration variants may not have fields")
242 }
243 }
244
245 match discriminant {
246 Some((_, expr)) => variants.push((ident, expr)),
247 None => bail!("Enumeration variants must have a disriminant"),
248 }
249 }
250
251 if variants.is_empty() {
252 panic!("Enumeration must have at least one variant");
253 }
254
255 let default = variants[0].0.clone();
256
257 let is_valid = variants
258 .iter()
259 .map(|&(_, ref value)| quote!(#value => true));
260 let from = variants.iter().map(
261 |&(ref variant, ref value)| quote!(#value => ::std::option::Option::Some(#ident::#variant)),
262 );
263
264 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
265 let from_i32_doc = format!(
266 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
267 ident
268 );
269
270 let expanded = quote! {
271 impl #impl_generics #ident #ty_generics #where_clause {
272 #[doc=#is_valid_doc]
273 pub fn is_valid(value: i32) -> bool {
274 match value {
275 #(#is_valid,)*
276 _ => false,
277 }
278 }
279
280 #[doc=#from_i32_doc]
281 pub fn from_i32(value: i32) -> ::std::option::Option<#ident> {
282 match value {
283 #(#from,)*
284 _ => ::std::option::Option::None,
285 }
286 }
287 }
288
289 impl #impl_generics ::std::default::Default for #ident #ty_generics #where_clause {
290 fn default() -> #ident {
291 #ident::#default
292 }
293 }
294
295 impl #impl_generics ::std::convert::From::<#ident> for i32 #ty_generics #where_clause {
296 fn from(value: #ident) -> i32 {
297 value as i32
298 }
299 }
300 };
301
302 Ok(expanded.into())
303}
304
305fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
306 let input: DeriveInput = syn::parse(input)?;
307
308 let ident = input.ident;
309
310 let variants = match input.data {
311 Data::Enum(DataEnum { variants, .. }) => variants,
312 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
313 Data::Union(..) => bail!("Oneof can not be derived for a union"),
314 };
315
316 let generics = &input.generics;
317 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
318
319 let mut fields: Vec<(Ident, Field)> = Vec::new();
321 for Variant {
322 attrs,
323 ident: variant_ident,
324 fields: variant_fields,
325 ..
326 } in variants
327 {
328 let variant_fields = match variant_fields {
329 Fields::Unit => Punctuated::new(),
330 Fields::Named(FieldsNamed { named: fields, .. })
331 | Fields::Unnamed(FieldsUnnamed {
332 unnamed: fields, ..
333 }) => fields,
334 };
335 if variant_fields.len() != 1 {
336 bail!("Oneof enum variants must have a single field");
337 }
338 match Field::new_oneof(attrs)? {
339 Some(field) => fields.push((variant_ident, field)),
340 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
341 }
342 }
343
344 let mut tags = fields
345 .iter()
346 .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
347 if field.tags().len() > 1 {
348 bail!(
349 "invalid oneof variant {}::{}: oneof variants may only have a single tag",
350 ident,
351 variant_ident
352 );
353 }
354 Ok(field.tags()[0])
355 })
356 .collect::<Vec<_>>();
357 tags.sort_unstable();
358 tags.dedup();
359 if tags.len() != fields.len() {
360 panic!("invalid oneof {}: variants have duplicate tags", ident);
361 }
362
363 let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
364 let encode = field.encode(quote!(*value));
365 quote!(#ident::#variant_ident(ref value) => { #encode })
366 });
367
368 let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
369 let tag = field.tags()[0];
370 quote! {
371 #tag => {
372 #ident::#variant_ident(NativeType::deserialize_default(wire_type, buf)?)
373 }
374 }
375 });
376
377 let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
378 let encoded_len = field.encoded_len(quote!(*value));
379 quote!(#ident::#variant_ident(ref value) => #encoded_len)
380 });
381
382 let debug = fields.iter().map(|&(ref variant_ident, _)| {
383 quote!(#ident::#variant_ident(ref value) => {
384 f.debug_tuple(stringify!(#variant_ident))
385 .field(value)
386 .finish()
387 })
388 });
389
390 let expanded = quote! {
391 impl ::ntex_grpc::types::OneofType for #impl_generics #ident #ty_generics #where_clause {
392 #[inline]
393 fn encode(&self, buf: &mut ::ntex_grpc::types::BytesMut) {
395 use ::ntex_grpc::NativeType;
396
397 match *self {
398 #(#encode,)*
399 }
400 }
401
402 #[inline]
403 fn decode(tag: u32, wire_type: ::ntex_grpc::types::WireType, buf: &mut ::ntex_grpc::types::Bytes) -> ::std::result::Result<Self, ::ntex_grpc::DecodeError> {
405 use ::ntex_grpc::NativeType;
406
407 Ok(match tag {
408 #(#merge,)*
409 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
410 })
411 }
412
413 #[inline]
415 fn encoded_len(&self) -> usize {
416 use ::ntex_grpc::NativeType;
417
418 match *self {
419 #(#encoded_len,)*
420 }
421 }
422 }
423
424 impl #impl_generics ::std::fmt::Debug for #ident #ty_generics #where_clause {
425 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
426 match *self {
427 #(#debug,)*
428 }
429 }
430 }
431 };
432
433 Ok(expanded.into())
434}