battler_wamprat_error_proc_macro/
lib.rs

1use battler_wamp_uri::Uri;
2use proc_macro2::{
3    Ident,
4    Span,
5    TokenStream,
6};
7use quote::quote;
8use syn::{
9    Data,
10    DeriveInput,
11    Error,
12    Fields,
13    LitStr,
14    Result,
15    parse::{
16        Parse,
17        ParseStream,
18    },
19    parse_macro_input,
20    spanned::Spanned,
21};
22
23struct StructInput {
24    ident: Ident,
25    uri: LitStr,
26    fields: Fields,
27}
28
29struct EnumVariant {
30    ident: Ident,
31    uri: LitStr,
32    fields: Fields,
33    span: Span,
34}
35
36struct EnumInput {
37    ident: Ident,
38    variants: Vec<EnumVariant>,
39}
40
41enum Input {
42    Struct(StructInput),
43    Enum(EnumInput),
44}
45
46impl Parse for Input {
47    fn parse(input: ParseStream) -> Result<Self> {
48        let call_site = Span::call_site();
49        let input = match DeriveInput::parse(input) {
50            Ok(item) => item,
51            Err(_) => return Err(Error::new(call_site, "input must be derive macro input")),
52        };
53        match input.data {
54            Data::Struct(data) => {
55                let ident = input.ident;
56                let uri = input
57                    .attrs
58                    .iter()
59                    .find(|attr| attr.path().is_ident("uri"))
60                    .and_then(|attr| {
61                        Some(attr.parse_args_with(|input: ParseStream| input.parse::<LitStr>()))
62                    })
63                    .ok_or_else(|| Error::new(call_site, "missing uri attribute"))??;
64                let fields = data.fields;
65                Ok(Self::Struct(StructInput { ident, uri, fields }))
66            }
67            Data::Enum(data) => {
68                let ident = input.ident;
69                let variants =
70                    data.variants
71                        .into_iter()
72                        .map(|variant| {
73                            let span = variant.span();
74                            let ident = variant.ident;
75                            let uri = variant
76                                .attrs
77                                .iter()
78                                .find(|attr| attr.path().is_ident("uri"))
79                                .and_then(|attr| {
80                                    Some(attr.parse_args_with(|input: ParseStream| {
81                                        input.parse::<LitStr>()
82                                    }))
83                                })
84                                .ok_or_else(|| Error::new(span, "missing uri attribute"))??;
85                            let fields = variant.fields;
86                            Ok(EnumVariant {
87                                ident,
88                                uri,
89                                fields,
90                                span,
91                            })
92                        })
93                        .collect::<Result<Vec<_>>>()?;
94                Ok(Self::Enum(EnumInput { ident, variants }))
95            }
96            Data::Union(_) => return Err(Error::new(call_site, "macro not allowed on a union")),
97        }
98    }
99}
100
101fn construct_from_fields(ty: TokenStream, fields: &Fields, span: Span) -> Result<TokenStream> {
102    match fields {
103        Fields::Named(fields) => {
104            if fields.named.len() > 1 {
105                return Err(Error::new(
106                    span,
107                    "struct must be constructible from a string",
108                ));
109            }
110            match fields.named.get(0) {
111                Some(field) => {
112                    let ident = field.ident.as_ref().unwrap();
113                    Ok(quote! {
114                        Ok(#ty { #ident: value.message().into() })
115                    })
116                }
117                None => Ok(quote! { Ok(#ty {})}),
118            }
119        }
120        Fields::Unnamed(fields) => {
121            if fields.unnamed.len() > 1 {
122                return Err(Error::new(
123                    span,
124                    "struct must be constructible from a string",
125                ));
126            }
127            if fields.unnamed.len() == 1 {
128                Ok(quote! {
129                    Ok(#ty(value.message().into()))
130                })
131            } else {
132                Ok(quote! {
133                    Ok(#ty())
134                })
135            }
136        }
137        Fields::Unit => Ok(quote! { Ok(#ty) }),
138    }
139}
140
141/// Procedural macro for generating conversions to and from
142/// [`battler_wamp::core::error::WampError`].
143#[proc_macro_derive(WampError, attributes(uri))]
144pub fn derive_wamp_uri_matcher(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
145    let input = parse_macro_input!(input as Input);
146    let call_site = Span::call_site();
147    let ident = match &input {
148        Input::Struct(input) => input.ident.clone(),
149        Input::Enum(input) => input.ident.clone(),
150    };
151
152    match &input {
153        Input::Struct(input) => {
154            if Uri::try_from(input.uri.value()).is_err() {
155                return proc_macro::TokenStream::from(
156                    Error::new(call_site, "invalid uri").into_compile_error(),
157                );
158            }
159        }
160        Input::Enum(input) => {
161            for variant in &input.variants {
162                if Uri::try_from(variant.uri.value()).is_err() {
163                    return proc_macro::TokenStream::from(
164                        Error::new(variant.span, "invalid uri").into_compile_error(),
165                    );
166                }
167            }
168        }
169    }
170
171    let into = match &input {
172        Input::Struct(input) => {
173            let uri = &input.uri;
174            quote! {
175                ::battler_wamp::core::error::WampError::new(::battler_wamp_uri::Uri::try_from(#uri).unwrap(), self.to_string())
176            }
177        }
178        Input::Enum(input) => {
179            let variants = input.variants.iter().map(|variant| {
180                let ident = &variant.ident;
181                let uri = &variant.uri;
182                quote! {
183                    Self::#ident { .. } => ::battler_wamp::core::error::WampError::new(::battler_wamp_uri::Uri::try_from(#uri).unwrap(), self.to_string())
184                }
185            });
186            quote! {
187                match self {
188                    #(#variants),*
189                }
190            }
191        }
192    };
193
194    let try_from = match &input {
195        Input::Struct(input) => {
196            let constructor = match construct_from_fields(quote!(Self), &input.fields, call_site) {
197                Ok(constructor) => constructor,
198                Err(err) => return proc_macro::TokenStream::from(err.into_compile_error()),
199            };
200            let uri = &input.uri;
201            quote! {
202                if value.reason().as_ref() == #uri {
203                    #constructor
204                } else {
205                    Err(value)
206                }
207            }
208        }
209        Input::Enum(input) => {
210            let variant_matchers = match input
211                .variants
212                .iter()
213                .map(|variant| {
214                    let ident = &variant.ident;
215                    let constructor =
216                        construct_from_fields(quote!(Self::#ident), &variant.fields, call_site)?;
217                    let uri = &variant.uri;
218                    Ok(quote! {
219                        if value.reason().as_ref() == #uri {
220                            return #constructor;
221                        }
222                    })
223                })
224                .collect::<Result<Vec<_>>>()
225            {
226                Ok(variant_matchers) => variant_matchers,
227                Err(err) => return proc_macro::TokenStream::from(err.into_compile_error()),
228            };
229            quote! {
230                #(#variant_matchers)*
231                Err(value)
232            }
233        }
234    };
235
236    quote! {
237        impl ::core::convert::Into<::battler_wamp::core::error::WampError> for #ident where #ident: ::std::string::ToString {
238            fn into(self) -> ::battler_wamp::core::error::WampError {
239                #into
240            }
241        }
242
243        impl ::core::convert::TryFrom<::battler_wamp::core::error::WampError> for #ident {
244            type Error = ::battler_wamp::core::error::WampError;
245            fn try_from(value: ::battler_wamp::core::error::WampError) -> ::core::result::Result<Self, Self::Error> {
246                #try_from
247            }
248        }
249    }
250    .into()
251}