prost_enum/
lib.rs

1extern crate proc_macro;
2
3mod parse;
4
5use parse::ProstEnum;
6use proc_macro::TokenStream;
7use proc_macro2::Ident;
8use proc_macro_error::proc_macro_error;
9use quote::quote;
10use syn::parse_macro_input;
11
12#[proc_macro_attribute]
13#[proc_macro_error]
14pub fn enhance(_args: TokenStream, input: TokenStream) -> TokenStream {
15    let prost_enum = {
16        let input = input.clone();
17        parse_macro_input!(input as ProstEnum)
18    };
19
20    let mut output = proc_macro2::TokenStream::new();
21    match prost_enum.repr {
22        Some(_) => {
23            output.extend(Some(quote! {
24                #[derive(prost_enum::Serialize_enum, prost_enum::Deserialize_enum)]
25            }));
26            #[cfg(feature = "sea-orm")]
27            output.extend(Some(quote! {
28                #[derive(sea_orm::entity::prelude::EnumIter, sea_orm::entity::prelude::DeriveActiveEnum)]
29                #[sea_orm(rs_type = "i32", db_type = "Integer")]
30            }));
31        }
32        None => output.extend(Some(quote! {
33            #[derive(serde::Serialize, serde::Deserialize)]
34        })),
35    }
36    output.extend(proc_macro2::TokenStream::from(input));
37    output.into()
38}
39
40#[proc_macro_derive(Serialize_enum)]
41pub fn derive_serialize(input: TokenStream) -> TokenStream {
42    let input = parse_macro_input!(input as ProstEnum);
43
44    match input.repr {
45        Some(_) => gen_serialize(input.ident),
46        None => TokenStream::from(quote! {}),
47    }
48}
49
50#[proc_macro_derive(Deserialize_enum, attributes(serde))]
51pub fn derive_deserialize(input: TokenStream) -> TokenStream {
52    let input = parse_macro_input!(input as ProstEnum);
53
54    match input.repr {
55        Some(_) => gen_deserialize(input.ident),
56        None => TokenStream::from(quote! {}),
57    }
58}
59
60fn gen_serialize(ident: Ident) -> TokenStream {
61    TokenStream::from(quote! {
62        impl serde::Serialize for #ident {
63            #[allow(clippy::use_self)]
64            fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
65            where
66                S: serde::Serializer
67            {
68                let value = self.as_str_name();
69                serde::Serialize::serialize(&value, serializer)
70            }
71        }
72    })
73}
74
75fn gen_deserialize(ident: Ident) -> TokenStream {
76    TokenStream::from(quote! {
77        impl<'de> serde::Deserialize<'de> for #ident {
78            #[allow(clippy::use_self)]
79            fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
80            where
81                D: serde::Deserializer<'de>,
82            {
83                struct discriminant;
84
85                impl<'de> serde::de::Visitor<'de> for discriminant {
86                    type Value = #ident;
87
88                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
89                        write!(formatter, "a string or an integer")
90                    }
91
92                    fn visit_str<R>(self, v: &str) -> Result<Self::Value, R>
93                    where
94                        R: serde::de::Error,
95                    {
96                        if v.is_empty() {
97                            match #ident::from_i32(0) {
98                                Some(e) => Ok(e),
99                                None => Err(serde::de::Error::custom(format!(
100                                    "unknown enum value: {}",
101                                    v
102                                )))
103                            }
104                        } else {
105                            match #ident::from_str_name(v) {
106                                Some(e) => Ok(e),
107                                None => Err(serde::de::Error::custom(format!(
108                                    "unknown enum value: {}",
109                                    v
110                                ))),
111                            }
112                        }
113                    }
114
115                    fn visit_i64<R>(self, v: i64) -> Result<Self::Value, R>
116                    where
117                        R: serde::de::Error,
118                    {
119                        match #ident::from_i32(v as i32) {
120                            Some(e) => Ok(e),
121                            None => Err(serde::de::Error::custom(format!(
122                                "unknown enum value: {}",
123                                v
124                            )))
125                        }
126                    }
127
128                    fn visit_u64<R>(self, v: u64) -> Result<Self::Value, R>
129                    where
130                        R: serde::de::Error,
131                    {
132                        match #ident::from_i32(v as i32) {
133                            Some(e) => Ok(e),
134                            None => Err(serde::de::Error::custom(format!(
135                                "unknown enum value: {}",
136                                v
137                            )))
138                        }
139                    }
140                }
141
142                deserializer.deserialize_any(discriminant)
143            }
144        }
145    })
146}