proto_json/lib.rs
1extern crate proc_macro;
2extern crate syn;
3use quote::quote;
4use syn::{
5 fold::{fold_type, Fold},
6 parse_macro_input, parse_quote,
7 punctuated::Punctuated,
8 visit::Visit,
9 Attribute, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, LitStr, Path, PathSegment,
10 Type, TypePath,
11};
12
13/// Helps to glue together json and protobufs when placed on a valid prost::Message, prost::Enumeration, or prost::Oneof
14/// For structs, it walks the fields checking whether they are enums then adds serialize_with and deserialize_with attributes to relevant fields.
15/// Structs need to have the prost::Message attribute while Enums require the prost::Enumeration attribute
16/// For enums, it checks that the provided string is a valid variant then deserializes it as i32 to match the protobuf definitions
17/// # Example
18/// ```
19/// # #[macros::proto_json]
20/// pub struct Address {
21/// country : String,
22/// city : String,
23/// state : Option<String>,
24/// street : String,
25/// line1 : String,
26/// line2 : Option<String>
27/// # }
28/// ```
29/// Example with enums
30/// ```
31/// # #[macros::proto_json]
32/// pub enum Currency {
33/// USD = 0;
34/// GPB = 1;
35/// JPY = 2;
36/// # }
37/// ```
38#[proc_macro_attribute]
39pub fn proto_json(
40 _attr: proc_macro::TokenStream,
41 input: proc_macro::TokenStream,
42) -> ::proc_macro::TokenStream {
43 let mut ast = parse_macro_input!(input as DeriveInput);
44
45 let ident = &ast.ident;
46
47 let mut is_prost_message = false;
48 let mut is_prost_enumeration = false;
49 let mut is_prost_one_of = false;
50
51 // If item does not implement one of the following attributes, then return error as it is not a valid protobuf object.
52 for attrib in ast.attrs.iter() {
53 if attrib.path().is_ident("derive") {
54 attrib
55 .parse_nested_meta(|meta| {
56 match meta.path.leading_colon {
57 Some(_) => match meta
58 .path
59 .segments
60 .last()
61 .unwrap()
62 .ident
63 .to_string()
64 .as_str()
65 {
66 "Message" => is_prost_message = true,
67 "Enumeration" => is_prost_enumeration = true,
68 "Oneof" => is_prost_one_of = true,
69 _ => (),
70 },
71 None => (),
72 }
73 Ok(())
74 })
75 .unwrap();
76 }
77 }
78
79 let generated: proc_macro2::TokenStream = match ast.data {
80 Data::Enum(ref mut de) => {
81 match is_prost_enumeration {
82 // Implement the str_to_i32 and i32_to_str methods
83 true => {
84 // Iterate over enum variants
85 let variants = de.variants.iter().map(|v| &v.ident);
86
87 // Convert variant name to snake_case
88 let variant_str_as_i32 = variants.clone().map(|variant| {
89 let variant_str = &::convert_case::Casing::to_case(
90 &variant.to_string(),
91 ::convert_case::Case::Snake,
92 );
93 // usd => Ok(Currency::Usd as i32)
94 quote! (#variant_str => Ok(#ident::#variant as i32))
95 });
96
97 // Creates a list of the enum's fields. Used to give useful error messages when deserializing.
98 let expected_fields = variants
99 .clone()
100 .map(|variant| {
101 let variant_str = ::convert_case::Casing::to_case(
102 &variant.to_string(),
103 ::convert_case::Case::Snake,
104 );
105 variant_str
106 })
107 .into_iter()
108 .map(|x| x)
109 .collect::<Vec<String>>()
110 .join(",");
111
112 let serde_funcs = quote! {
113 /// Methods for converting from Protofbuf to and from Json enums
114 impl #ident {
115 /// Deserialize enum from string to protobuf i32
116 pub fn str_to_i32<'de, D>(deserializer: D) -> core::result::Result<i32, D::Error>
117 where
118 D: serde::de::Deserializer<'de>,
119 {
120 let s: &str = serde::de::Deserialize::deserialize(deserializer)?;
121
122 match s.to_lowercase().as_str() {
123 #(#variant_str_as_i32,)*
124 _ => core::result::Result::Err(serde::de::Error::unknown_variant(s, &[#expected_fields])),
125 }
126 }
127 /// Deserialize optional enum from string to optional protobuf i32
128 pub fn str_to_i32_opt<'de, D>(deserializer: D) -> core::result::Result<Option<i32>, D::Error>
129 where
130 D: serde::de::Deserializer<'de>,
131 {
132 let s: Option<&str> = serde::de::Deserialize::deserialize(deserializer)?;
133
134 if let Some(s) = s {
135 return Ok(Some(
136 s.to_lowercase()
137 .as_str()
138 .parse::<Self>()
139 .map_err(|_| serde::de::Error::unknown_variant(s, &[#expected_fields]))?
140 as i32));
141 }
142
143 Ok(None)
144 }
145 /// Serialize enum from protobuf i32 to json string
146 pub fn i32_to_str<S>(data: &i32, serializer: S) -> core::result::Result<S::Ok, S::Error>
147 where
148 S: serde::Serializer,
149 {
150 serializer.serialize_str(&Self::try_from(data.to_owned()).unwrap().to_string())
151 }
152 /// Serialize enum from optional protobuf i32 to optional json string
153 pub fn i32_to_str_opt<S>(
154 data: &Option<i32>,
155 serializer: S,
156 ) -> core::result::Result<S::Ok, S::Error>
157 where
158 S: serde::Serializer,
159 {
160 if let Some(ref d) = *data {
161 return serializer.serialize_str(&Self::try_from(d.to_owned()).unwrap().to_string());
162 }
163 serializer.serialize_none()
164 }
165 // Deserialize from a vec string to a vec i32
166 pub fn vec_str_to_vec_i32<'de, D>(deserializer: D) -> core::result::Result<Vec<i32>, D::Error>
167 where
168 D: serde::de::Deserializer<'de>,
169 {
170 let strings: Vec<&str> = serde::de::Deserialize::deserialize(deserializer)?;
171
172 let mut result = Vec::with_capacity(strings.len());
173
174 for s in strings {
175 match s.parse::<Self>() {
176 Ok(num) => result.push(num as i32),
177 Err(_) => {
178 return Err(serde::de::Error::invalid_value(
179 serde::de::Unexpected::Str(s),
180 &#expected_fields,
181 ))
182 }
183 }
184 }
185
186 Ok(result)
187 }
188 // Serializes a vec of enum i32s to a vec of strings
189 pub fn vec_i32_to_vec_str<S>(
190 data: &Vec<i32>,
191 serializer: S,
192 ) -> core::result::Result<S::Ok, S::Error>
193 where
194 S: serde::Serializer,
195 {
196 let mut seq = serializer.serialize_seq(Some(data.len()))?;
197
198 for &i in data {
199 serde::ser::SerializeSeq::serialize_element(
200 &mut seq,
201 &Self::try_from(i.to_owned()).unwrap().to_string(),
202 )?;
203 }
204
205 serde::ser::SerializeSeq::end(seq)
206 }
207 }
208
209 };
210
211 quote! {
212 #ast
213 #serde_funcs
214 }
215 .into()
216 }
217 // Check if the given item has the prost::Oneof attribute which indicates an enum nested inside a module
218 false => match is_prost_one_of {
219 true => {
220 let attribute: Attribute = parse_quote! {
221 #[derive(serde::Serialize, serde::Deserialize)]
222 };
223
224 let attribute2: Attribute = parse_quote!(
225 #[serde(rename_all = "snake_case")]
226 );
227
228 ast.attrs.push(attribute);
229 ast.attrs.push(attribute2);
230
231 // Generate a new TokenTree of data type DataEnum
232 let new = Data::Enum(DataEnum {
233 enum_token: de.enum_token,
234 brace_token: de.brace_token,
235 variants: de.variants.clone(),
236 });
237
238 // Copy over the attributes from the original enum for max compatibility
239 let new_enum = DeriveInput {
240 attrs: ast.attrs,
241 vis: ast.vis,
242 ident: ident.clone(),
243 generics: ast.generics,
244 data: new,
245 };
246
247 quote! {#new_enum}.into()
248 }
249 false => {
250 return ::syn::Error::new_spanned(
251 &ident,
252 "Could not parse the item as a valid protobuf Enum",
253 )
254 .to_compile_error()
255 .into();
256 }
257 },
258 }
259 }
260 // A struct that implements prost::Message
261 ::syn::Data::Struct(ref mut ds) => match &ds.fields {
262 Fields::Named(fields) => {
263 match is_prost_message {
264 true => {
265 let mut new_fields = fields.to_owned();
266
267 new_fields.named.iter_mut().for_each(|f| match is_option(f) {
268 true => {
269 match check_struct_field_for_prost_enumeration_attribute(&f.attrs) {
270 Some(a) => {
271 match check_is_vec(&f.ty) {
272 // check whether field is vec
273 true => {
274 let serializer = format!("{a}::vec_i32_to_vec_str");
275 let deserializer = format!("{a}::vec_str_to_vec_i32");
276
277 // Create a new serialize_with, deserialize_with attribute
278 let new_attr: Attribute = parse_quote! {
279 #[serde( default, deserialize_with = #deserializer, serialize_with = #serializer)]
280 };
281
282 f.attrs.push(new_attr);
283 },
284 false => {
285 let serializer = format!("{a}::i32_to_str_opt");
286 let deserializer = format!("{a}::str_to_i32_opt");
287
288 // Create a new serialize_with, deserialize_with attribute
289 let new_attr: Attribute = parse_quote! {
290 #[serde( default, deserialize_with = #deserializer, serialize_with = #serializer, skip_serializing_if = "Option::is_none" )]
291 };
292
293 f.attrs.push(new_attr);
294 }
295 }
296 },
297 None => ()
298 }
299 }
300 false => {
301
302 match check_struct_field_for_prost_enumeration_attribute(&f.attrs) {
303 Some(a) => {
304 match check_is_vec(&f.ty) {
305 true => {
306 let serializer = format!("{a}::vec_i32_to_vec_str");
307 let deserializer = format!("{a}::vec_str_to_vec_i32");
308 // Create a new serialize_with, deserialize_with attribute
309 let new_attr: Attribute = parse_quote! {
310 #[serde(default, deserialize_with = #deserializer, serialize_with = #serializer)]
311 };
312 f.attrs.push(new_attr);
313 },
314 false => {
315 let serializer = format!("{a}::i32_to_str");
316 let deserializer = format!("{a}::str_to_i32");
317 // Create a new serialize_with, deserialize_with attribute
318 let new_attr: Attribute = parse_quote! {
319 #[serde(default, deserialize_with = #deserializer, serialize_with = #serializer)]
320 };
321 f.attrs.push(new_attr);
322 }
323 }
324 },
325 None => ()
326 }
327 }
328 });
329
330 let new = Data::Struct(DataStruct {
331 struct_token: ds.struct_token,
332 fields: Fields::Named(new_fields),
333 semi_token: ds.semi_token,
334 });
335
336 let new_st = DeriveInput {
337 attrs: ast.attrs,
338 vis: ast.vis,
339 ident: ident.clone(),
340 generics: ast.generics,
341 data: new,
342 };
343
344 quote! {#new_st}.into()
345 }
346 false => {
347 return ::syn::Error::new_spanned(
348 &ident,
349 "ProtoJson only works with Protobuf Structs",
350 )
351 .to_compile_error()
352 .into();
353 }
354 }
355 }
356 _ => {
357 return ::syn::Error::new_spanned(
358 &ds.fields,
359 "ProtoJson only supports named field structs",
360 )
361 .to_compile_error()
362 .into();
363 }
364 },
365 _ => {
366 return ::syn::Error::new_spanned(
367 &ident,
368 "Only items with named fields can derive ProtoJson",
369 )
370 .to_compile_error()
371 .into();
372 }
373 };
374
375 ::proc_macro::TokenStream::from(generated)
376}
377
378/// Checks if a struct field is Optional
379fn is_option(field: &Field) -> bool {
380 let typ = &field.ty;
381
382 let opt = match typ {
383 Type::Path(typepath) if typepath.qself.is_none() => Some(typepath.path.clone()),
384 _ => None,
385 };
386
387 if let Some(o) = opt {
388 check_for_option(&o).is_some()
389 } else {
390 false
391 }
392}
393
394/// Walks the path segments to check for Option
395fn check_for_option(path: &Path) -> Option<&PathSegment> {
396 let idents_of_path = path.segments.iter().fold(String::new(), |mut acc, v| {
397 acc.push_str(&v.ident.to_string());
398 acc.push(':');
399 acc
400 });
401 vec!["Option:", "std:option:Option:", "core:option:Option:"]
402 .into_iter()
403 .find(|s| idents_of_path == *s)
404 .and_then(|_| path.segments.last())
405}
406
407
408/// Checks whether the attribute has a prost-enumeration member and returns an optional ident of the name
409fn check_struct_field_for_prost_enumeration_attribute(attrs: &[Attribute]) -> Option<String> {
410 // If a #[prost(enumeration = "value")] exists, this is the optional name of the value
411 let mut attrib: Option<String> = None;
412
413 //
414 for attr in attrs {
415 // Looks for attributes in the form #[prost(enumeration = "Currency", tag = "2")]
416 if attr.path().is_ident("prost") {
417 // Check for the "enumeration" part in the enumeration = "Currency" and parse it as KV.
418 attr.parse_nested_meta(|meta| {
419 // #[prost(enumeration = "Ident")]
420 if meta.path.is_ident("enumeration") {
421 // Get the Currency in enumeration = "Currency". Gives back a string literal (LitStr)
422 let value = meta.value().unwrap();
423 //
424 attrib = Some(value.parse::<LitStr>().unwrap().value());
425
426 return Ok(());
427 // If there is no enumeration attribute, do nothing
428 } else {
429 Ok(())
430 }
431 })
432 .unwrap_or(());
433 }
434 }
435 attrib
436}
437
438struct VecTypeVisitor {
439 is_vec: bool,
440}
441
442impl<'ast> Visit<'ast> for VecTypeVisitor {
443 fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
444 if segment.ident == "Vec" {
445 self.is_vec = true;
446 }
447 }
448}
449
450impl Fold for VecTypeVisitor {
451 fn fold_type_path(&mut self, type_path: TypePath) -> TypePath {
452 let mut new_segments = Punctuated::new();
453
454 for segment in type_path.path.segments {
455 new_segments.push(self.fold_path_segment(segment));
456 }
457
458 TypePath {
459 qself: type_path.qself,
460 path: syn::Path {
461 leading_colon: type_path.path.leading_colon,
462 segments: new_segments,
463 },
464 }
465 }
466
467 fn fold_path_segment(&mut self, segment: PathSegment) -> PathSegment {
468 let new_segment = segment.clone();
469 if segment.ident == "Vec" {
470 self.is_vec = true;
471 }
472 new_segment
473 }
474
475 fn fold_type(&mut self, ty: Type) -> Type {
476 match ty {
477 Type::Path(type_path) => Type::Path(self.fold_type_path(type_path)),
478 Type::Tuple(type_tuple) => {
479 let new_elems = type_tuple
480 .elems
481 .into_iter()
482 .map(|ty| self.fold_type(ty))
483 .collect();
484 Type::Tuple(syn::TypeTuple {
485 paren_token: type_tuple.paren_token,
486 elems: new_elems,
487 })
488 }
489 // Add cases for other types as needed
490 _ => ty,
491 }
492 }
493}
494
495fn check_is_vec(ty: &syn::Type) -> bool {
496 let mut visitor = VecTypeVisitor { is_vec: false };
497 fold_type(&mut visitor, ty.clone());
498 visitor.is_vec
499}