// rocket_extra_codegen/lib.rs

1#![recursion_limit = "128"]
2
3extern crate proc_macro;
4
/// Name of the helper attribute that selects `FromRequest::Error`, written as
/// `#[error_type = "MyType"]`. Must stay in sync with the
/// `attributes(error_type)` list on the `FromRequest` derive.
const ERROR_TYPE_ATTRIBUTE: &str = "error_type";
6
7use crate::proc_macro::TokenStream;
8use quote::quote;
9use syn;
10use syn::Attribute;
11use syn::Data;
12use syn::DeriveInput;
13use syn::Error;
14use syn::Fields;
15use syn::Lit;
16use syn::Meta;
17use syn::Type;
18
19#[proc_macro_derive(FromRequest, attributes(error_type))]
20pub fn derive_from_request(input: TokenStream) -> TokenStream {
21    try_derive_from_request(input).unwrap_or_else(|err| err.to_compile_error().into())
22}
23
24fn try_derive_from_request(input: TokenStream) -> Result<TokenStream, Error> {
25    let ast = syn::parse::<DeriveInput>(input)?;
26
27    let name = ast.ident;
28    let fields = match ast.data {
29        Data::Struct(struct_) => struct_.fields,
30        Data::Enum(enum_) => return Err(Error::new_spanned(enum_.enum_token, "Should be a struct")),
31        Data::Union(union_) => {
32            return Err(Error::new_spanned(union_.union_token, "Should be a struct"));
33        }
34    };
35
36    if !ast.generics.params.is_empty() {
37        return Err(Error::new_spanned(
38            ast.generics,
39            "Generics are not yet supported",
40        ));
41    }
42
43    if let Fields::Unnamed(unnamed) = fields {
44        return Err(Error::new_spanned(
45            unnamed,
46            "Should be a struct with named fields",
47        ));
48    }
49
50    let error_type_declaration = match get_error_type(&ast.attrs)? {
51        Some(error_type) => quote! { type Error = #error_type; },
52        None => quote! { type Error = (); },
53    };
54
55    let mut arms = Vec::new();
56    let mut constructor = Vec::new();
57
58    for field in &fields {
59        let name = field.ident.as_ref().expect("Unexpected unnamed field");
60        let ty = &field.ty;
61        arms.push(quote! {
62            let #name = match ::rocket::Request::guard::<#ty>(request) {
63                ::rocket::Outcome::Success(user) => user,
64                ::rocket::Outcome::Failure((status, error)) => return ::rocket::Outcome::Failure((status, ::std::convert::From::from(error))),
65                ::rocket::Outcome::Forward(()) => return ::rocket::Outcome::Forward(()),
66            };
67        });
68
69        constructor.push(quote! { #name: #name })
70    }
71
72    let trait_implementation = quote! {
73        impl<'a, 'r> ::rocket::request::FromRequest<'a, 'r> for #name {
74            #error_type_declaration
75
76            fn from_request(request: &'a ::rocket::Request<'r>) -> ::rocket::Outcome<Self, (::rocket::http::Status, Self::Error), ()> {
77                #(#arms)*
78                ::rocket::Outcome::Success(#name { #(#constructor),*})
79            }
80        }
81    };
82    Ok(trait_implementation.into())
83}
84
85fn get_error_type(attrs: &[Attribute]) -> Result<Option<Type>, Error> {
86    let mut error_type_decls = attrs
87        .iter()
88        .filter_map(|attr| attr.parse_meta().ok().map(|meta| (attr, meta)))
89        .filter(|(_attr, meta)| meta.name() == ERROR_TYPE_ATTRIBUTE)
90        .collect::<Vec<_>>();
91
92    let error_type_attr_meta = match (error_type_decls.pop(), error_type_decls.pop()) {
93        (None, _) => return Ok(None),
94        (Some((_attr, meta)), None) => meta,
95        (Some((attr, _meta)), Some(_)) => {
96            return Err(Error::new_spanned(
97                attr,
98                format!("Found more than one `{}` declaration", ERROR_TYPE_ATTRIBUTE),
99            ));
100        }
101    };
102
103    let name_value = if let Meta::NameValue(name_value) = error_type_attr_meta {
104        name_value
105    } else {
106        return Err(Error::new_spanned(
107            error_type_attr_meta,
108            format!(
109                "Expected a name-value attribute, e.g. `#[{} = \"MyType\"]`",
110                ERROR_TYPE_ATTRIBUTE
111            ),
112        ));
113    };
114
115    match name_value.lit {
116        Lit::Str(lit_str) => match lit_str.parse() {
117            Ok(type_spec) => Ok(Some(type_spec)),
118            Err(_) => Err(Error::new_spanned(lit_str, "Invalid type specifier")),
119        },
120        other => Err(Error::new_spanned(other, "Invalid string literal")),
121    }
122}