1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, ToTokens, TokenStreamExt};
5
6#[proc_macro_attribute]
15pub fn amqp(
16 attr: proc_macro::TokenStream,
17 item: proc_macro::TokenStream,
18) -> proc_macro::TokenStream {
19 let (impls, attrs) = match syn::parse::<syn::Item>(item.clone()).unwrap() {
20 syn::Item::Enum(item) => (enum_serde(item), None),
21 syn::Item::Struct(item) => struct_serde(item, attr),
22 _ => panic!("amqp attribute can only be applied to enum or struct"),
23 };
24
25 let mut new = attrs.unwrap_or_else(proc_macro::TokenStream::new);
26 new.extend(item);
27 new.extend(impls);
28 new
29}
30
31fn enum_serde(def: syn::ItemEnum) -> proc_macro::TokenStream {
32 let name = &def.ident;
33 let (_, orig_ty_generics, _) = def.generics.split_for_impl();
34 let mut generics = def.generics.clone();
35 let mut lt_def = syn::LifetimeDef {
36 attrs: Vec::new(),
37 lifetime: syn::Lifetime::new("'de", Span::call_site()),
38 colon_token: None,
39 bounds: syn::punctuated::Punctuated::new(),
40 };
41
42 if def.generics.lifetimes().count() > 0 {
43 lt_def.bounds = def
44 .generics
45 .lifetimes()
46 .map(|def| def.lifetime.clone())
47 .collect();
48 }
49
50 generics.params = Some(syn::GenericParam::Lifetime(lt_def))
51 .into_iter()
52 .chain(generics.params)
53 .collect();
54
55 let de_life = syn::Lifetime::new("'de", Span::call_site());
56 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
57
58 let screaming = translate(&def.ident.to_string());
59 let scope = format_ident!("_IMPL_DESERIALIZER_FOR_{}", screaming);
60 let name_str = syn::LitStr::new(&name.to_string(), Span::call_site());
61
62 let mut field_variants = TokenStream::new();
63 for i in 0..def.variants.len() {
64 let name = format_ident!("F{}", i);
65 field_variants.append_all(quote!(#name,));
66 }
67
68 match def.variants.first().unwrap().fields {
69 syn::Fields::Unnamed(_) => {}
70 _ => panic!("struct variants are not supported"),
71 };
72
73 let mut tag_u64 = TokenStream::new();
74 let mut bytes_arms = TokenStream::new();
75 let mut variants = TokenStream::new();
76 let mut visitor_arms = TokenStream::new();
77
78 let mut int_arms = TokenStream::new();
79 for (i, var) in def.variants.iter().enumerate() {
80 let fields = match &var.fields {
81 syn::Fields::Unnamed(f) => f,
82 _ => panic!("only unnamed fields allowed here"),
83 };
84
85 if fields.unnamed.len() != 1 {
86 panic!("only 1 unnamed field is allowed");
87 }
88
89 let ty = match &fields.unnamed.first().unwrap().ty {
90 syn::Type::Path(p) => p,
91 p => panic!("only path types allowed: {}", p.into_token_stream()),
92 };
93
94 let variant = format_ident!("F{}", i);
95 let mut ty_name = ty.clone();
96 let mut segment = ty_name.path.segments.last_mut().unwrap();
97 segment.arguments = syn::PathArguments::None;
98 int_arms.append_all(quote!(#ty_name::CODE => std::result::Result::Ok(Field::#variant),));
99 bytes_arms.append_all(quote!(#ty_name::NAME => std::result::Result::Ok(Field::#variant),));
100
101 let variant_name = syn::LitStr::new(&var.ident.to_string(), Span::call_site());
102 variants.append_all(quote!(#variant_name,));
103
104 let var_ident = &var.ident;
105 visitor_arms.append_all(quote!(
106 (Field::#variant, __variant) => Result::map(
107 serde::de::VariantAccess::newtype_variant::<#ty_name>(__variant),
108 #name::#var_ident,
109 ),
110 ));
111 }
112
113 tag_u64.append_all(quote!(
114 fn visit_u64<E>(
115 self,
116 value: u64,
117 ) -> std::result::Result<Self::Value, E>
118 where
119 E: serde::de::Error,
120 {
121 match Some(value) {
122 #int_arms
123 _ => std::result::Result::Err(serde::de::Error::invalid_value(
124 serde::de::Unexpected::Unsigned(value),
125 &"invalid descriptor ID",
126 )),
127 }
128 }
129 ));
130
131 let res = quote!(
132 const #scope: () = {
133 use serde;
134 use std::fmt;
135
136 impl #impl_generics serde::Deserialize<#de_life> for #name #orig_ty_generics #where_clause {
137 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
138 where
139 D: serde::Deserializer<#de_life>,
140 {
141 enum Field { #field_variants }
142
143 struct FieldVisitor;
144
145 impl #impl_generics serde::de::Visitor<#de_life> for FieldVisitor {
146 type Value = Field;
147
148 fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
149 fmt::Formatter::write_str(fmt, "variant identifier")
150 }
151
152 #tag_u64
153
154 fn visit_bytes<E>(
155 self,
156 value: &[u8],
157 ) -> std::result::Result<Self::Value, E>
158 where
159 E: serde::de::Error,
160 {
161 match Some(value) {
162 #bytes_arms
163 _ => {
164 let value = std::string::String::from_utf8_lossy(value);
165 std::result::Result::Err(serde::de::Error::unknown_variant(
166 &value, VARIANTS,
167 ))
168 }
169 }
170 }
171 }
172
173 impl<#de_life> serde::Deserialize<#de_life> for Field {
174 #[inline]
175 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
176 where
177 D: serde::Deserializer<#de_life>,
178 {
179 serde::Deserializer::deserialize_identifier(deserializer, FieldVisitor)
180 }
181 }
182
183 struct Visitor #ty_generics {
184 marker: std::marker::PhantomData<#name#orig_ty_generics>,
185 lifetime: std::marker::PhantomData<&#de_life ()>,
186 }
187
188 impl #impl_generics serde::de::Visitor<#de_life> for Visitor #ty_generics {
189 type Value = #name #orig_ty_generics;
190 fn expecting(
191 &self,
192 fmt: &mut fmt::Formatter,
193 ) -> fmt::Result {
194 fmt::Formatter::write_str(fmt, "enum #name_str")
195 }
196 fn visit_enum<__A>(
197 self,
198 __data: __A,
199 ) -> std::result::Result<Self::Value, __A::Error>
200 where
201 __A: serde::de::EnumAccess<#de_life>,
202 {
203 match match serde::de::EnumAccess::variant(__data) {
204 std::result::Result::Ok(__val) => __val,
205 std::result::Result::Err(__err) => {
206 return std::result::Result::Err(__err);
207 }
208 } {
209 #visitor_arms
210 }
211
212 }
213
214 }
215
216 const VARIANTS: &[&'static str] = &[
217 #variants
218 ];
219
220 serde::Deserializer::deserialize_enum(
221 deserializer,
222 #name_str,
223 VARIANTS,
224 Visitor {
225 marker: std::marker::PhantomData::<#name#orig_ty_generics>,
226 lifetime: std::marker::PhantomData,
227 },
228 )
229 }
230 }
231 };
232 );
233
234 res.into()
235}
236
237fn struct_serde(
238 def: syn::ItemStruct,
239 meta: proc_macro::TokenStream,
240) -> (proc_macro::TokenStream, Option<proc_macro::TokenStream>) {
241 if meta.is_empty() {
242 panic!("no arguments found for attribute on struct type");
243 }
244
245 let list = syn::parse::<syn::MetaList>(meta).unwrap();
246 if !list.path.is_ident("descriptor") {
247 panic!("invalid attribute {:?}", list.path.get_ident().unwrap());
248 }
249
250 let (name, code) = if list.nested.len() == 2 {
251 let name = if let Some(syn::NestedMeta::Lit(syn::Lit::Str(s))) = list.nested.first() {
252 s.value()
253 } else {
254 panic!("could not extract descriptor name from attribute");
255 };
256
257 let id = if let Some(syn::NestedMeta::Lit(syn::Lit::Int(s))) = list.nested.last() {
258 s.clone()
259 } else {
260 panic!("could not extract descriptor ID from attribute");
261 };
262
263 (Some(name), Some(id))
264 } else {
265 assert_eq!(list.nested.len(), 1);
266 let pair =
267 if let Some(syn::NestedMeta::Meta(syn::Meta::NameValue(pair))) = list.nested.first() {
268 pair
269 } else {
270 panic!("could not extract descriptor name or code");
271 };
272
273 if pair.path.is_ident("name") {
274 if let syn::Lit::Str(s) = &pair.lit {
275 (Some(s.value()), None)
276 } else {
277 panic!("invalid type for descriptor name");
278 }
279 } else if pair.path.is_ident("code") {
280 if let syn::Lit::Int(s) = &pair.lit {
281 (None, Some(s.clone()))
282 } else {
283 panic!("invalid type for descriptor name");
284 }
285 } else {
286 panic!(
287 "invalid descriptor element {:?}",
288 pair.path.get_ident().unwrap()
289 );
290 }
291 };
292
293 let ident = def.ident;
294 let generics = def.generics;
295
296 let renamed = format!(
297 "{}|{}",
298 name.clone().unwrap_or_else(|| "".into()),
299 code.clone()
300 .map_or("".into(), |i| i.base10_digits().to_string())
301 );
302 let none = quote!(None);
303 let name = name.map_or(none.clone(), |s| {
304 let lit = syn::LitByteStr::new(s.as_bytes(), Span::call_site());
305 quote!(Some(#lit))
306 });
307 let code = code.map_or(none, |i| quote!(Some(#i)));
308
309 let described = quote!(
310 impl#generics Described for #ident#generics {
311 const NAME: Option<&'static [u8]> = #name;
312 const CODE: Option<u64> = #code;
313 }
314 );
315
316 let rename = quote!(#[derive(Deserialize)] #[serde(rename = #renamed)]);
317 (described.into(), Some(rename.into()))
318}
319
/// Converts a CamelCase identifier into an upper-cased, underscore-separated
/// form for use in generated const names (`FooBar` -> `FOO_BAR`).
///
/// An underscore is inserted before every uppercase character except the
/// first (Unicode `is_uppercase`). Each character is then upper-cased with
/// ASCII rules only. Consecutive capitals are therefore split
/// (`HTTP` -> `H_T_T_P`), and non-ASCII lowercase letters are left unchanged.
fn translate(s: &str) -> String {
    s.char_indices()
        .fold(String::with_capacity(s.len() * 2), |mut out, (idx, c)| {
            if c.is_uppercase() && idx != 0 {
                out.push('_');
            }
            out.push(c.to_ascii_uppercase());
            out
        })
}