rocket_dependency_injection_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{Block, DeriveInput, Error};
5
6extern crate proc_macro;
7
8#[proc_macro_derive(Resolve)]
9pub fn derive(input: TokenStream) -> TokenStream {
10    let derive_input: DeriveInput = syn::parse(input).unwrap();
11
12    let name = &derive_input.ident;
13    let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
14
15    quote::quote! {
16        impl #impl_generics rocket_dependency_injection::Resolve for #name #ty_generics #where_clause {
17            fn resolve(service_provider: &rocket_dependency_injection::ServiceProvider) -> Self {
18                Self::resolve_generated(service_provider)
19            }
20        }
21    }
22    .into()
23}
24
25#[proc_macro_attribute]
26pub fn resolve_constructor(_: TokenStream, input: TokenStream) -> TokenStream {
27    let function: syn::ItemFn = syn::parse(input).unwrap();
28
29    validate_function(&function).unwrap();
30
31    let resolve_generated = generate_resolve_generated(&function.sig);
32
33    quote! {
34        #function
35
36        #resolve_generated
37    }
38    .into()
39}
40
41fn generate_resolve_generated(sig: &syn::Signature) -> proc_macro2::TokenStream {
42    let fn_name = sig.ident.clone();
43    let calls: Vec<proc_macro2::TokenStream> = sig
44        .inputs
45        .iter()
46        .map(|_| quote!(service_provider.unwrap()))
47        .collect();
48
49    quote! {
50        #[allow(unused_variables)]
51        pub fn resolve_generated(service_provider: &rocket_dependency_injection::ServiceProvider) -> Self {
52            Self::#fn_name(#(#calls),*)
53        }
54    }
55}
56
57fn validate_function(function: &syn::ItemFn) -> syn::Result<()> {
58    validate_sig(&function.sig)?;
59    validate_block(&function.block)?;
60    Ok(())
61}
62
63fn validate_block(block: &Block) -> syn::Result<()> {
64    match block.stmts.len() {
65        0 => Err(Error::new(
66            Span::call_site(),
67            "resolve_constructor cannot be set on a function with empty block",
68        )),
69        _ => Ok(()),
70    }
71}
72
73fn validate_sig(sig: &syn::Signature) -> syn::Result<()> {
74    validate_output(&sig.output)?;
75    validate_arguments(&sig.inputs)?;
76    if sig.asyncness.is_some() {
77        Err(Error::new(
78            Span::call_site(),
79            "resolve_constructor cannot be set on an async function",
80        ))
81    } else if sig.constness.is_some() {
82        Err(Error::new(
83            Span::call_site(),
84            "resolve_constructor cannot be set on a const function",
85        ))
86    } else {
87        Ok(())
88    }
89}
90
91fn validate_arguments(
92    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
93) -> syn::Result<()> {
94    inputs.iter().all(|arg| match arg {
95        syn::FnArg::Typed(typ) if !is_self_type(&typ.ty) => true,
96        _ => false
97    }).then_some(()).ok_or(Error::new(Span::call_site(), "resolve_constructor cannot be set on a function that have an argument of type Self (either self or Self)"))
98}
99
100fn validate_output(output: &syn::ReturnType) -> syn::Result<()> {
101    match output {
102        syn::ReturnType::Type(_, typ) if is_self_type(typ) => Ok(()),
103        _ => Err(Error::new(
104            Span::call_site(),
105            "resolve_constructor cannot be set on a function that does not return Self",
106        )),
107    }
108}
109
110fn is_self_type(typ: &Box<syn::Type>) -> bool {
111    match &**typ {
112        syn::Type::Path(path) => match path.path.segments.len() {
113            1 => match path
114                .path
115                .segments
116                .last()
117                .unwrap()
118                .ident
119                .to_string()
120                .as_str()
121            {
122                "Self" => true,
123                _ => false,
124            },
125            _ => false,
126        },
127        _ => false,
128    }
129}