proto_json/
lib.rs

1extern crate proc_macro;
2extern crate syn;
3use quote::quote;
4use syn::{
5    fold::{fold_type, Fold},
6    parse_macro_input, parse_quote,
7    punctuated::Punctuated,
8    visit::Visit,
9    Attribute, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, LitStr, Path, PathSegment,
10    Type, TypePath,
11};
12
13/// Helps to glue together json and protobufs when placed on a valid prost::Message, prost::Enumeration, or prost::Oneof
14/// For structs, it walks the fields checking whether they are enums then adds serialize_with and deserialize_with attributes to relevant fields.
15/// Structs need to have the prost::Message attribute while Enums require the prost::Enumeration attribute
16/// For enums, it checks that the provided string is a valid variant then deserializes it as i32 to match the protobuf definitions
17/// # Example
18/// ```
19/// # #[macros::proto_json]
20/// pub struct Address {
21///     country : String,
22///     city : String,
23///     state : Option<String>,
24///     street : String,
25///     line1   : String,
26///     line2   : Option<String>
27/// # }
28/// ```
29/// Example with enums
30/// ```
31/// # #[macros::proto_json]
32/// pub enum Currency {
33///         USD = 0;
34///         GPB = 1;
35///         JPY = 2;
36/// # }
37/// ```
38#[proc_macro_attribute]
39pub fn proto_json(
40    _attr: proc_macro::TokenStream,
41    input: proc_macro::TokenStream,
42) -> ::proc_macro::TokenStream {
43    let mut ast = parse_macro_input!(input as DeriveInput);
44
45    let ident = &ast.ident;
46
47    let mut is_prost_message = false;
48    let mut is_prost_enumeration = false;
49    let mut is_prost_one_of = false;
50
51    // If item does not implement one of the following attributes, then return error as it is not a valid protobuf object.
52    for attrib in ast.attrs.iter() {
53        if attrib.path().is_ident("derive") {
54            attrib
55                .parse_nested_meta(|meta| {
56                    match meta.path.leading_colon {
57                        Some(_) => match meta
58                            .path
59                            .segments
60                            .last()
61                            .unwrap()
62                            .ident
63                            .to_string()
64                            .as_str()
65                        {
66                            "Message" => is_prost_message = true,
67                            "Enumeration" => is_prost_enumeration = true,
68                            "Oneof" => is_prost_one_of = true,
69                            _ => (),
70                        },
71                        None => (),
72                    }
73                    Ok(())
74                })
75                .unwrap();
76        }
77    }
78
79    let generated: proc_macro2::TokenStream = match ast.data {
80        Data::Enum(ref mut de) => {
81            match is_prost_enumeration {
82                // Implement the str_to_i32 and i32_to_str methods
83                true => {
84                    // Iterate over enum variants
85                    let variants = de.variants.iter().map(|v| &v.ident);
86
87                    // Convert variant name to snake_case
88                    let variant_str_as_i32 = variants.clone().map(|variant| {
89                        let variant_str = &::convert_case::Casing::to_case(
90                            &variant.to_string(),
91                            ::convert_case::Case::Snake,
92                        );
93                        // usd => Ok(Currency::Usd as i32)
94                        quote! (#variant_str => Ok(#ident::#variant as i32))
95                    });
96
97                    // Creates a list of the enum's fields. Used to give useful error messages when deserializing.
98                    let expected_fields = variants
99                        .clone()
100                        .map(|variant| {
101                            let variant_str = ::convert_case::Casing::to_case(
102                                &variant.to_string(),
103                                ::convert_case::Case::Snake,
104                            );
105                            variant_str
106                        })
107                        .into_iter()
108                        .map(|x| x)
109                        .collect::<Vec<String>>()
110                        .join(",");
111
112                    let serde_funcs = quote! {
113                        /// Methods for converting from Protofbuf to and from Json enums
114                        impl #ident {
115                            /// Deserialize enum from string to protobuf i32
116                            pub fn  str_to_i32<'de, D>(deserializer: D) -> core::result::Result<i32, D::Error>
117                            where
118                                D: serde::de::Deserializer<'de>,
119                            {
120                                let s: &str = serde::de::Deserialize::deserialize(deserializer)?;
121
122                                match s.to_lowercase().as_str() {
123                                    #(#variant_str_as_i32,)*
124                                    _ => core::result::Result::Err(serde::de::Error::unknown_variant(s, &[#expected_fields])),
125                                }
126                            }
127                            /// Deserialize optional enum from string to optional protobuf i32
128                            pub fn str_to_i32_opt<'de, D>(deserializer: D) -> core::result::Result<Option<i32>, D::Error>
129                            where
130                                D: serde::de::Deserializer<'de>,
131                            {
132                                let s: Option<&str> = serde::de::Deserialize::deserialize(deserializer)?;
133
134                                if let Some(s) = s {
135                                    return Ok(Some(
136                                        s.to_lowercase()
137                                            .as_str()
138                                            .parse::<Self>()
139                                            .map_err(|_| serde::de::Error::unknown_variant(s, &[#expected_fields]))?
140                                            as i32));
141                                }
142
143                                Ok(None)
144                            }
145                            /// Serialize enum from protobuf i32 to json string
146                            pub fn i32_to_str<S>(data: &i32, serializer: S) -> core::result::Result<S::Ok, S::Error>
147                            where
148                                S: serde::Serializer,
149                            {
150                                serializer.serialize_str(&Self::try_from(data.to_owned()).unwrap().to_string())
151                            }
152                            /// Serialize enum from optional protobuf i32 to optional json string
153                            pub fn i32_to_str_opt<S>(
154                                data: &Option<i32>,
155                                serializer: S,
156                            ) -> core::result::Result<S::Ok, S::Error>
157                            where
158                                S: serde::Serializer,
159                            {
160                                if let Some(ref d) = *data {
161                                    return serializer.serialize_str(&Self::try_from(d.to_owned()).unwrap().to_string());
162                                }
163                                serializer.serialize_none()
164                            }
165                            // Deserialize from a vec string to a vec i32
166                            pub fn vec_str_to_vec_i32<'de, D>(deserializer: D) -> core::result::Result<Vec<i32>, D::Error>
167                            where
168                                D: serde::de::Deserializer<'de>,
169                            {
170                                let strings: Vec<&str> = serde::de::Deserialize::deserialize(deserializer)?;
171
172                                let mut result = Vec::with_capacity(strings.len());
173
174                                for s in strings {
175                                    match s.parse::<Self>() {
176                                        Ok(num) => result.push(num as i32),
177                                        Err(_) => {
178                                            return Err(serde::de::Error::invalid_value(
179                                                serde::de::Unexpected::Str(s),
180                                                &#expected_fields,
181                                            ))
182                                        }
183                                    }
184                                }
185
186                                Ok(result)
187                            }
188                            // Serializes a vec of enum i32s to a vec of strings
189                            pub fn vec_i32_to_vec_str<S>(
190                                data: &Vec<i32>,
191                                serializer: S,
192                            ) -> core::result::Result<S::Ok, S::Error>
193                            where
194                                S: serde::Serializer,
195                            {
196                                let mut seq = serializer.serialize_seq(Some(data.len()))?;
197
198                                for &i in data {
199                                    serde::ser::SerializeSeq::serialize_element(
200                                        &mut seq,
201                                        &Self::try_from(i.to_owned()).unwrap().to_string(),
202                                    )?;
203                                }
204
205                                serde::ser::SerializeSeq::end(seq)
206                            }
207                        }
208
209                    };
210
211                    quote! {
212                        #ast
213                        #serde_funcs
214                    }
215                    .into()
216                }
217                // Check if the given item has the prost::Oneof attribute which indicates an enum nested inside a module
218                false => match is_prost_one_of {
219                    true => {
220                        let attribute: Attribute = parse_quote! {
221                            #[derive(serde::Serialize, serde::Deserialize)]
222                        };
223
224                        let attribute2: Attribute = parse_quote!(
225                            #[serde(rename_all = "snake_case")]
226                        );
227
228                        ast.attrs.push(attribute);
229                        ast.attrs.push(attribute2);
230
231                        // Generate a new TokenTree of data type DataEnum
232                        let new = Data::Enum(DataEnum {
233                            enum_token: de.enum_token,
234                            brace_token: de.brace_token,
235                            variants: de.variants.clone(),
236                        });
237
238                        // Copy over the attributes from the original enum for max compatibility
239                        let new_enum = DeriveInput {
240                            attrs: ast.attrs,
241                            vis: ast.vis,
242                            ident: ident.clone(),
243                            generics: ast.generics,
244                            data: new,
245                        };
246
247                        quote! {#new_enum}.into()
248                    }
249                    false => {
250                        return ::syn::Error::new_spanned(
251                            &ident,
252                            "Could not parse the item as a valid protobuf Enum",
253                        )
254                        .to_compile_error()
255                        .into();
256                    }
257                },
258            }
259        }
260        // A struct that implements prost::Message
261        ::syn::Data::Struct(ref mut ds) => match &ds.fields {
262            Fields::Named(fields) => {
263                match is_prost_message {
264                    true => {
265                        let mut new_fields = fields.to_owned();
266
267                        new_fields.named.iter_mut().for_each(|f| match is_option(f) {
268                            true => {
269                                match check_struct_field_for_prost_enumeration_attribute(&f.attrs) {
270                                    Some(a) => {
271                                        match check_is_vec(&f.ty) {
272                                            // check whether field is vec 
273                                            true => {
274                                                let serializer = format!("{a}::vec_i32_to_vec_str");
275                                                let deserializer = format!("{a}::vec_str_to_vec_i32");
276
277                                                // Create a new serialize_with, deserialize_with attribute
278                                                let new_attr: Attribute = parse_quote! {
279                                                    #[serde( default, deserialize_with = #deserializer, serialize_with = #serializer)]
280                                                };
281
282                                                f.attrs.push(new_attr);
283                                            },
284                                            false => {
285                                                let serializer = format!("{a}::i32_to_str_opt");
286                                                let deserializer = format!("{a}::str_to_i32_opt");
287
288                                                // Create a new serialize_with, deserialize_with attribute
289                                                let new_attr: Attribute = parse_quote! {
290                                                    #[serde( default, deserialize_with = #deserializer, serialize_with = #serializer, skip_serializing_if = "Option::is_none" )]
291                                                };
292
293                                                f.attrs.push(new_attr);
294                                            }
295                                        }
296                                    },
297                                    None => ()
298                                }
299                            }
300                            false => {
301
302                                match check_struct_field_for_prost_enumeration_attribute(&f.attrs) {
303                                    Some(a) => {
304                                        match check_is_vec(&f.ty) {
305                                            true => {
306                                                let serializer = format!("{a}::vec_i32_to_vec_str");
307                                                let deserializer = format!("{a}::vec_str_to_vec_i32");
308                                                // Create a new serialize_with, deserialize_with attribute
309                                                let new_attr: Attribute = parse_quote! {
310                                                    #[serde(default, deserialize_with = #deserializer, serialize_with = #serializer)]
311                                                };
312                                                f.attrs.push(new_attr);
313                                            },
314                                            false => {
315                                                let serializer = format!("{a}::i32_to_str");
316                                                let deserializer = format!("{a}::str_to_i32");
317                                                // Create a new serialize_with, deserialize_with attribute
318                                                let new_attr: Attribute = parse_quote! {
319                                                    #[serde(default, deserialize_with = #deserializer, serialize_with = #serializer)]
320                                                };
321                                                f.attrs.push(new_attr);
322                                            }
323                                        }
324                                    },
325                                    None => ()
326                                }
327                            }
328                        });
329
330                        let new = Data::Struct(DataStruct {
331                            struct_token: ds.struct_token,
332                            fields: Fields::Named(new_fields),
333                            semi_token: ds.semi_token,
334                        });
335
336                        let new_st = DeriveInput {
337                            attrs: ast.attrs,
338                            vis: ast.vis,
339                            ident: ident.clone(),
340                            generics: ast.generics,
341                            data: new,
342                        };
343
344                        quote! {#new_st}.into()
345                    }
346                    false => {
347                        return ::syn::Error::new_spanned(
348                            &ident,
349                            "ProtoJson only works with Protobuf Structs",
350                        )
351                        .to_compile_error()
352                        .into();
353                    }
354                }
355            }
356            _ => {
357                return ::syn::Error::new_spanned(
358                    &ds.fields,
359                    "ProtoJson only supports named field structs",
360                )
361                .to_compile_error()
362                .into();
363            }
364        },
365        _ => {
366            return ::syn::Error::new_spanned(
367                &ident,
368                "Only items with named fields can derive ProtoJson",
369            )
370            .to_compile_error()
371            .into();
372        }
373    };
374
375    ::proc_macro::TokenStream::from(generated)
376}
377
378/// Checks if a struct field is Optional
379fn is_option(field: &Field) -> bool {
380    let typ = &field.ty;
381
382    let opt = match typ {
383        Type::Path(typepath) if typepath.qself.is_none() => Some(typepath.path.clone()),
384        _ => None,
385    };
386
387    if let Some(o) = opt {
388        check_for_option(&o).is_some()
389    } else {
390        false
391    }
392}
393
394/// Walks the path segments to check for Option
395fn check_for_option(path: &Path) -> Option<&PathSegment> {
396    let idents_of_path = path.segments.iter().fold(String::new(), |mut acc, v| {
397        acc.push_str(&v.ident.to_string());
398        acc.push(':');
399        acc
400    });
401    vec!["Option:", "std:option:Option:", "core:option:Option:"]
402        .into_iter()
403        .find(|s| idents_of_path == *s)
404        .and_then(|_| path.segments.last())
405}
406
407
408/// Checks whether the attribute has a prost-enumeration member and returns an optional ident of the name
409fn check_struct_field_for_prost_enumeration_attribute(attrs: &[Attribute]) -> Option<String> {
410    // If a #[prost(enumeration = "value")] exists, this is the optional name of the value
411    let mut attrib: Option<String> = None;
412
413    //
414    for attr in attrs {
415        // Looks for attributes in the form #[prost(enumeration = "Currency", tag = "2")]
416        if attr.path().is_ident("prost") {
417            // Check for the "enumeration" part in the enumeration = "Currency" and parse it as KV.
418            attr.parse_nested_meta(|meta| {
419                // #[prost(enumeration = "Ident")]
420                if meta.path.is_ident("enumeration") {
421                    // Get the Currency in enumeration = "Currency". Gives back a string literal (LitStr)
422                    let value = meta.value().unwrap();
423                    //
424                    attrib = Some(value.parse::<LitStr>().unwrap().value());
425
426                    return Ok(());
427                // If there is no enumeration attribute, do nothing
428                } else {
429                    Ok(())
430                }
431            })
432            .unwrap_or(());
433        }
434    }
435    attrib
436}
437
438struct VecTypeVisitor {
439    is_vec: bool,
440}
441
442impl<'ast> Visit<'ast> for VecTypeVisitor {
443    fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
444        if segment.ident == "Vec" {
445            self.is_vec = true;
446        }
447    }
448}
449
450impl Fold for VecTypeVisitor {
451    fn fold_type_path(&mut self, type_path: TypePath) -> TypePath {
452        let mut new_segments = Punctuated::new();
453
454        for segment in type_path.path.segments {
455            new_segments.push(self.fold_path_segment(segment));
456        }
457
458        TypePath {
459            qself: type_path.qself,
460            path: syn::Path {
461                leading_colon: type_path.path.leading_colon,
462                segments: new_segments,
463            },
464        }
465    }
466
467    fn fold_path_segment(&mut self, segment: PathSegment) -> PathSegment {
468        let new_segment = segment.clone();
469        if segment.ident == "Vec" {
470            self.is_vec = true;
471        }
472        new_segment
473    }
474
475    fn fold_type(&mut self, ty: Type) -> Type {
476        match ty {
477            Type::Path(type_path) => Type::Path(self.fold_type_path(type_path)),
478            Type::Tuple(type_tuple) => {
479                let new_elems = type_tuple
480                    .elems
481                    .into_iter()
482                    .map(|ty| self.fold_type(ty))
483                    .collect();
484                Type::Tuple(syn::TypeTuple {
485                    paren_token: type_tuple.paren_token,
486                    elems: new_elems,
487                })
488            }
489            // Add cases for other types as needed
490            _ => ty,
491        }
492    }
493}
494
495fn check_is_vec(ty: &syn::Type) -> bool {
496    let mut visitor = VecTypeVisitor { is_vec: false };
497    fold_type(&mut visitor, ty.clone());
498    visitor.is_vec
499}