oasis_amqp_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, ToTokens, TokenStreamExt};
5
6/// Implement AMQP 1.0-related functionality for structs and enums
7///
8/// For enums: this macro provides a custom implementation of serde::Deserialize. Only newtype
9/// variants and unit variants are supported; all variants within an enum should be of the same
10/// type.
11///
12/// For structs: this macro is used to implement the `oasis-amqp::Described` trait. It also
13/// ensures `serde::Deserialize` is implemented for a type.
14#[proc_macro_attribute]
15pub fn amqp(
16    attr: proc_macro::TokenStream,
17    item: proc_macro::TokenStream,
18) -> proc_macro::TokenStream {
19    let (impls, attrs) = match syn::parse::<syn::Item>(item.clone()).unwrap() {
20        syn::Item::Enum(item) => (enum_serde(item), None),
21        syn::Item::Struct(item) => struct_serde(item, attr),
22        _ => panic!("amqp attribute can only be applied to enum or struct"),
23    };
24
25    let mut new = attrs.unwrap_or_else(proc_macro::TokenStream::new);
26    new.extend(item);
27    new.extend(impls);
28    new
29}
30
31fn enum_serde(def: syn::ItemEnum) -> proc_macro::TokenStream {
32    let name = &def.ident;
33    let (_, orig_ty_generics, _) = def.generics.split_for_impl();
34    let mut generics = def.generics.clone();
35    let mut lt_def = syn::LifetimeDef {
36        attrs: Vec::new(),
37        lifetime: syn::Lifetime::new("'de", Span::call_site()),
38        colon_token: None,
39        bounds: syn::punctuated::Punctuated::new(),
40    };
41
42    if def.generics.lifetimes().count() > 0 {
43        lt_def.bounds = def
44            .generics
45            .lifetimes()
46            .map(|def| def.lifetime.clone())
47            .collect();
48    }
49
50    generics.params = Some(syn::GenericParam::Lifetime(lt_def))
51        .into_iter()
52        .chain(generics.params)
53        .collect();
54
55    let de_life = syn::Lifetime::new("'de", Span::call_site());
56    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
57
58    let screaming = translate(&def.ident.to_string());
59    let scope = format_ident!("_IMPL_DESERIALIZER_FOR_{}", screaming);
60    let name_str = syn::LitStr::new(&name.to_string(), Span::call_site());
61
62    let mut field_variants = TokenStream::new();
63    for i in 0..def.variants.len() {
64        let name = format_ident!("F{}", i);
65        field_variants.append_all(quote!(#name,));
66    }
67
68    match def.variants.first().unwrap().fields {
69        syn::Fields::Unnamed(_) => {}
70        _ => panic!("struct variants are not supported"),
71    };
72
73    let mut tag_u64 = TokenStream::new();
74    let mut bytes_arms = TokenStream::new();
75    let mut variants = TokenStream::new();
76    let mut visitor_arms = TokenStream::new();
77
78    let mut int_arms = TokenStream::new();
79    for (i, var) in def.variants.iter().enumerate() {
80        let fields = match &var.fields {
81            syn::Fields::Unnamed(f) => f,
82            _ => panic!("only unnamed fields allowed here"),
83        };
84
85        if fields.unnamed.len() != 1 {
86            panic!("only 1 unnamed field is allowed");
87        }
88
89        let ty = match &fields.unnamed.first().unwrap().ty {
90            syn::Type::Path(p) => p,
91            p => panic!("only path types allowed: {}", p.into_token_stream()),
92        };
93
94        let variant = format_ident!("F{}", i);
95        let mut ty_name = ty.clone();
96        let mut segment = ty_name.path.segments.last_mut().unwrap();
97        segment.arguments = syn::PathArguments::None;
98        int_arms.append_all(quote!(#ty_name::CODE => std::result::Result::Ok(Field::#variant),));
99        bytes_arms.append_all(quote!(#ty_name::NAME => std::result::Result::Ok(Field::#variant),));
100
101        let variant_name = syn::LitStr::new(&var.ident.to_string(), Span::call_site());
102        variants.append_all(quote!(#variant_name,));
103
104        let var_ident = &var.ident;
105        visitor_arms.append_all(quote!(
106            (Field::#variant, __variant) => Result::map(
107                serde::de::VariantAccess::newtype_variant::<#ty_name>(__variant),
108                #name::#var_ident,
109            ),
110        ));
111    }
112
113    tag_u64.append_all(quote!(
114        fn visit_u64<E>(
115            self,
116            value: u64,
117        ) -> std::result::Result<Self::Value, E>
118        where
119            E: serde::de::Error,
120        {
121            match Some(value) {
122                #int_arms
123                _ => std::result::Result::Err(serde::de::Error::invalid_value(
124                    serde::de::Unexpected::Unsigned(value),
125                    &"invalid descriptor ID",
126                )),
127            }
128        }
129    ));
130
131    let res = quote!(
132        const #scope: () = {
133            use serde;
134            use std::fmt;
135
136            impl #impl_generics serde::Deserialize<#de_life> for #name #orig_ty_generics #where_clause {
137                fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
138                where
139                    D: serde::Deserializer<#de_life>,
140                {
141                    enum Field { #field_variants }
142
143                    struct FieldVisitor;
144
145                    impl #impl_generics serde::de::Visitor<#de_life> for FieldVisitor {
146                        type Value = Field;
147
148                        fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
149                            fmt::Formatter::write_str(fmt, "variant identifier")
150                        }
151
152                        #tag_u64
153
154                        fn visit_bytes<E>(
155                            self,
156                            value: &[u8],
157                        ) -> std::result::Result<Self::Value, E>
158                        where
159                            E: serde::de::Error,
160                        {
161                            match Some(value) {
162                                #bytes_arms
163                                _ => {
164                                    let value = std::string::String::from_utf8_lossy(value);
165                                    std::result::Result::Err(serde::de::Error::unknown_variant(
166                                        &value, VARIANTS,
167                                    ))
168                                }
169                            }
170                        }
171                    }
172
173                    impl<#de_life> serde::Deserialize<#de_life> for Field {
174                        #[inline]
175                        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
176                        where
177                            D: serde::Deserializer<#de_life>,
178                        {
179                            serde::Deserializer::deserialize_identifier(deserializer, FieldVisitor)
180                        }
181                    }
182
183                    struct Visitor #ty_generics {
184                        marker: std::marker::PhantomData<#name#orig_ty_generics>,
185                        lifetime: std::marker::PhantomData<&#de_life ()>,
186                    }
187
188                    impl #impl_generics serde::de::Visitor<#de_life> for Visitor #ty_generics {
189                        type Value = #name #orig_ty_generics;
190                        fn expecting(
191                            &self,
192                            fmt: &mut fmt::Formatter,
193                        ) -> fmt::Result {
194                            fmt::Formatter::write_str(fmt, "enum #name_str")
195                        }
196                        fn visit_enum<__A>(
197                            self,
198                            __data: __A,
199                        ) -> std::result::Result<Self::Value, __A::Error>
200                        where
201                            __A: serde::de::EnumAccess<#de_life>,
202                        {
203                            match match serde::de::EnumAccess::variant(__data) {
204                                std::result::Result::Ok(__val) => __val,
205                                std::result::Result::Err(__err) => {
206                                    return std::result::Result::Err(__err);
207                                }
208                            } {
209                                #visitor_arms
210                            }
211
212                        }
213
214                    }
215
216                    const VARIANTS: &[&'static str] = &[
217                        #variants
218                    ];
219
220                    serde::Deserializer::deserialize_enum(
221                        deserializer,
222                        #name_str,
223                        VARIANTS,
224                        Visitor {
225                            marker: std::marker::PhantomData::<#name#orig_ty_generics>,
226                            lifetime: std::marker::PhantomData,
227                        },
228                    )
229                }
230            }
231        };
232    );
233
234    res.into()
235}
236
237fn struct_serde(
238    def: syn::ItemStruct,
239    meta: proc_macro::TokenStream,
240) -> (proc_macro::TokenStream, Option<proc_macro::TokenStream>) {
241    if meta.is_empty() {
242        panic!("no arguments found for attribute on struct type");
243    }
244
245    let list = syn::parse::<syn::MetaList>(meta).unwrap();
246    if !list.path.is_ident("descriptor") {
247        panic!("invalid attribute {:?}", list.path.get_ident().unwrap());
248    }
249
250    let (name, code) = if list.nested.len() == 2 {
251        let name = if let Some(syn::NestedMeta::Lit(syn::Lit::Str(s))) = list.nested.first() {
252            s.value()
253        } else {
254            panic!("could not extract descriptor name from attribute");
255        };
256
257        let id = if let Some(syn::NestedMeta::Lit(syn::Lit::Int(s))) = list.nested.last() {
258            s.clone()
259        } else {
260            panic!("could not extract descriptor ID from attribute");
261        };
262
263        (Some(name), Some(id))
264    } else {
265        assert_eq!(list.nested.len(), 1);
266        let pair =
267            if let Some(syn::NestedMeta::Meta(syn::Meta::NameValue(pair))) = list.nested.first() {
268                pair
269            } else {
270                panic!("could not extract descriptor name or code");
271            };
272
273        if pair.path.is_ident("name") {
274            if let syn::Lit::Str(s) = &pair.lit {
275                (Some(s.value()), None)
276            } else {
277                panic!("invalid type for descriptor name");
278            }
279        } else if pair.path.is_ident("code") {
280            if let syn::Lit::Int(s) = &pair.lit {
281                (None, Some(s.clone()))
282            } else {
283                panic!("invalid type for descriptor name");
284            }
285        } else {
286            panic!(
287                "invalid descriptor element {:?}",
288                pair.path.get_ident().unwrap()
289            );
290        }
291    };
292
293    let ident = def.ident;
294    let generics = def.generics;
295
296    let renamed = format!(
297        "{}|{}",
298        name.clone().unwrap_or_else(|| "".into()),
299        code.clone()
300            .map_or("".into(), |i| i.base10_digits().to_string())
301    );
302    let none = quote!(None);
303    let name = name.map_or(none.clone(), |s| {
304        let lit = syn::LitByteStr::new(s.as_bytes(), Span::call_site());
305        quote!(Some(#lit))
306    });
307    let code = code.map_or(none, |i| quote!(Some(#i)));
308
309    let described = quote!(
310        impl#generics Described for #ident#generics {
311            const NAME: Option<&'static [u8]> = #name;
312            const CODE: Option<u64> = #code;
313        }
314    );
315
316    let rename = quote!(#[derive(Deserialize)] #[serde(rename = #renamed)]);
317    (described.into(), Some(rename.into()))
318}
319
/// Turns a `CamelCase` identifier into `SCREAMING_SNAKE_CASE` (e.g. `FooBar` -> `FOO_BAR`).
///
/// An underscore is placed before every uppercase character except one at the very start, and
/// each character is ASCII-uppercased. Consecutive capitals are each separated
/// (`ABc` -> `A_BC`).
fn translate(s: &str) -> String {
    s.char_indices()
        .flat_map(|(offset, c)| {
            let separator = if offset > 0 && c.is_uppercase() {
                Some('_')
            } else {
                None
            };
            separator
                .into_iter()
                .chain(std::iter::once(c.to_ascii_uppercase()))
        })
        .collect()
}