rocket_extra_codegen/
lib.rs1#![recursion_limit = "128"]
2
3extern crate proc_macro;
4
/// Name of the helper attribute (registered via `attributes(error_type)` on the
/// derive) that selects the generated `type Error`, e.g. `#[error_type = "MyType"]`.
/// When the attribute is absent, the generated `type Error` is `()`.
static ERROR_TYPE_ATTRIBUTE: &str = "error_type";
6
7use crate::proc_macro::TokenStream;
8use quote::quote;
9use syn;
10use syn::Attribute;
11use syn::Data;
12use syn::DeriveInput;
13use syn::Error;
14use syn::Fields;
15use syn::Lit;
16use syn::Meta;
17use syn::Type;
18
19#[proc_macro_derive(FromRequest, attributes(error_type))]
20pub fn derive_from_request(input: TokenStream) -> TokenStream {
21 try_derive_from_request(input).unwrap_or_else(|err| err.to_compile_error().into())
22}
23
24fn try_derive_from_request(input: TokenStream) -> Result<TokenStream, Error> {
25 let ast = syn::parse::<DeriveInput>(input)?;
26
27 let name = ast.ident;
28 let fields = match ast.data {
29 Data::Struct(struct_) => struct_.fields,
30 Data::Enum(enum_) => return Err(Error::new_spanned(enum_.enum_token, "Should be a struct")),
31 Data::Union(union_) => {
32 return Err(Error::new_spanned(union_.union_token, "Should be a struct"));
33 }
34 };
35
36 if !ast.generics.params.is_empty() {
37 return Err(Error::new_spanned(
38 ast.generics,
39 "Generics are not yet supported",
40 ));
41 }
42
43 if let Fields::Unnamed(unnamed) = fields {
44 return Err(Error::new_spanned(
45 unnamed,
46 "Should be a struct with named fields",
47 ));
48 }
49
50 let error_type_declaration = match get_error_type(&ast.attrs)? {
51 Some(error_type) => quote! { type Error = #error_type; },
52 None => quote! { type Error = (); },
53 };
54
55 let mut arms = Vec::new();
56 let mut constructor = Vec::new();
57
58 for field in &fields {
59 let name = field.ident.as_ref().expect("Unexpected unnamed field");
60 let ty = &field.ty;
61 arms.push(quote! {
62 let #name = match ::rocket::Request::guard::<#ty>(request) {
63 ::rocket::Outcome::Success(user) => user,
64 ::rocket::Outcome::Failure((status, error)) => return ::rocket::Outcome::Failure((status, ::std::convert::From::from(error))),
65 ::rocket::Outcome::Forward(()) => return ::rocket::Outcome::Forward(()),
66 };
67 });
68
69 constructor.push(quote! { #name: #name })
70 }
71
72 let trait_implementation = quote! {
73 impl<'a, 'r> ::rocket::request::FromRequest<'a, 'r> for #name {
74 #error_type_declaration
75
76 fn from_request(request: &'a ::rocket::Request<'r>) -> ::rocket::Outcome<Self, (::rocket::http::Status, Self::Error), ()> {
77 #(#arms)*
78 ::rocket::Outcome::Success(#name { #(#constructor),*})
79 }
80 }
81 };
82 Ok(trait_implementation.into())
83}
84
85fn get_error_type(attrs: &[Attribute]) -> Result<Option<Type>, Error> {
86 let mut error_type_decls = attrs
87 .iter()
88 .filter_map(|attr| attr.parse_meta().ok().map(|meta| (attr, meta)))
89 .filter(|(_attr, meta)| meta.name() == ERROR_TYPE_ATTRIBUTE)
90 .collect::<Vec<_>>();
91
92 let error_type_attr_meta = match (error_type_decls.pop(), error_type_decls.pop()) {
93 (None, _) => return Ok(None),
94 (Some((_attr, meta)), None) => meta,
95 (Some((attr, _meta)), Some(_)) => {
96 return Err(Error::new_spanned(
97 attr,
98 format!("Found more than one `{}` declaration", ERROR_TYPE_ATTRIBUTE),
99 ));
100 }
101 };
102
103 let name_value = if let Meta::NameValue(name_value) = error_type_attr_meta {
104 name_value
105 } else {
106 return Err(Error::new_spanned(
107 error_type_attr_meta,
108 format!(
109 "Expected a name-value attribute, e.g. `#[{} = \"MyType\"]`",
110 ERROR_TYPE_ATTRIBUTE
111 ),
112 ));
113 };
114
115 match name_value.lit {
116 Lit::Str(lit_str) => match lit_str.parse() {
117 Ok(type_spec) => Ok(Some(type_spec)),
118 Err(_) => Err(Error::new_spanned(lit_str, "Invalid type specifier")),
119 },
120 other => Err(Error::new_spanned(other, "Invalid string literal")),
121 }
122}