1extern crate proc_macro;
2
3mod parse;
4
5use parse::ProstEnum;
6use proc_macro::TokenStream;
7use proc_macro2::Ident;
8use proc_macro_error::proc_macro_error;
9use quote::quote;
10use syn::parse_macro_input;
11
12#[proc_macro_attribute]
13#[proc_macro_error]
14pub fn enhance(_args: TokenStream, input: TokenStream) -> TokenStream {
15 let prost_enum = {
16 let input = input.clone();
17 parse_macro_input!(input as ProstEnum)
18 };
19
20 let mut output = proc_macro2::TokenStream::new();
21 match prost_enum.repr {
22 Some(_) => {
23 output.extend(Some(quote! {
24 #[derive(prost_enum::Serialize_enum, prost_enum::Deserialize_enum)]
25 }));
26 #[cfg(feature = "sea-orm")]
27 output.extend(Some(quote! {
28 #[derive(sea_orm::entity::prelude::EnumIter, sea_orm::entity::prelude::DeriveActiveEnum)]
29 #[sea_orm(rs_type = "i32", db_type = "Integer")]
30 }));
31 }
32 None => output.extend(Some(quote! {
33 #[derive(serde::Serialize, serde::Deserialize)]
34 })),
35 }
36 output.extend(proc_macro2::TokenStream::from(input));
37 output.into()
38}
39
40#[proc_macro_derive(Serialize_enum)]
41pub fn derive_serialize(input: TokenStream) -> TokenStream {
42 let input = parse_macro_input!(input as ProstEnum);
43
44 match input.repr {
45 Some(_) => gen_serialize(input.ident),
46 None => TokenStream::from(quote! {}),
47 }
48}
49
50#[proc_macro_derive(Deserialize_enum, attributes(serde))]
51pub fn derive_deserialize(input: TokenStream) -> TokenStream {
52 let input = parse_macro_input!(input as ProstEnum);
53
54 match input.repr {
55 Some(_) => gen_deserialize(input.ident),
56 None => TokenStream::from(quote! {}),
57 }
58}
59
60fn gen_serialize(ident: Ident) -> TokenStream {
61 TokenStream::from(quote! {
62 impl serde::Serialize for #ident {
63 #[allow(clippy::use_self)]
64 fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
65 where
66 S: serde::Serializer
67 {
68 let value = self.as_str_name();
69 serde::Serialize::serialize(&value, serializer)
70 }
71 }
72 })
73}
74
75fn gen_deserialize(ident: Ident) -> TokenStream {
76 TokenStream::from(quote! {
77 impl<'de> serde::Deserialize<'de> for #ident {
78 #[allow(clippy::use_self)]
79 fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
80 where
81 D: serde::Deserializer<'de>,
82 {
83 struct discriminant;
84
85 impl<'de> serde::de::Visitor<'de> for discriminant {
86 type Value = #ident;
87
88 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
89 write!(formatter, "a string or an integer")
90 }
91
92 fn visit_str<R>(self, v: &str) -> Result<Self::Value, R>
93 where
94 R: serde::de::Error,
95 {
96 if v.is_empty() {
97 match #ident::from_i32(0) {
98 Some(e) => Ok(e),
99 None => Err(serde::de::Error::custom(format!(
100 "unknown enum value: {}",
101 v
102 )))
103 }
104 } else {
105 match #ident::from_str_name(v) {
106 Some(e) => Ok(e),
107 None => Err(serde::de::Error::custom(format!(
108 "unknown enum value: {}",
109 v
110 ))),
111 }
112 }
113 }
114
115 fn visit_i64<R>(self, v: i64) -> Result<Self::Value, R>
116 where
117 R: serde::de::Error,
118 {
119 match #ident::from_i32(v as i32) {
120 Some(e) => Ok(e),
121 None => Err(serde::de::Error::custom(format!(
122 "unknown enum value: {}",
123 v
124 )))
125 }
126 }
127
128 fn visit_u64<R>(self, v: u64) -> Result<Self::Value, R>
129 where
130 R: serde::de::Error,
131 {
132 match #ident::from_i32(v as i32) {
133 Some(e) => Ok(e),
134 None => Err(serde::de::Error::custom(format!(
135 "unknown enum value: {}",
136 v
137 )))
138 }
139 }
140 }
141
142 deserializer.deserialize_any(discriminant)
143 }
144 }
145 })
146}