rocket_dependency_injection_derive/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{Block, DeriveInput, Error};
5
6extern crate proc_macro;
7
8#[proc_macro_derive(Resolve)]
9pub fn derive(input: TokenStream) -> TokenStream {
10 let derive_input: DeriveInput = syn::parse(input).unwrap();
11
12 let name = &derive_input.ident;
13 let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
14
15 quote::quote! {
16 impl #impl_generics rocket_dependency_injection::Resolve for #name #ty_generics #where_clause {
17 fn resolve(service_provider: &rocket_dependency_injection::ServiceProvider) -> Self {
18 Self::resolve_generated(service_provider)
19 }
20 }
21 }
22 .into()
23}
24
25#[proc_macro_attribute]
26pub fn resolve_constructor(_: TokenStream, input: TokenStream) -> TokenStream {
27 let function: syn::ItemFn = syn::parse(input).unwrap();
28
29 validate_function(&function).unwrap();
30
31 let resolve_generated = generate_resolve_generated(&function.sig);
32
33 quote! {
34 #function
35
36 #resolve_generated
37 }
38 .into()
39}
40
41fn generate_resolve_generated(sig: &syn::Signature) -> proc_macro2::TokenStream {
42 let fn_name = sig.ident.clone();
43 let calls: Vec<proc_macro2::TokenStream> = sig
44 .inputs
45 .iter()
46 .map(|_| quote!(service_provider.unwrap()))
47 .collect();
48
49 quote! {
50 #[allow(unused_variables)]
51 pub fn resolve_generated(service_provider: &rocket_dependency_injection::ServiceProvider) -> Self {
52 Self::#fn_name(#(#calls),*)
53 }
54 }
55}
56
57fn validate_function(function: &syn::ItemFn) -> syn::Result<()> {
58 validate_sig(&function.sig)?;
59 validate_block(&function.block)?;
60 Ok(())
61}
62
63fn validate_block(block: &Block) -> syn::Result<()> {
64 match block.stmts.len() {
65 0 => Err(Error::new(
66 Span::call_site(),
67 "resolve_constructor cannot be set on a function with empty block",
68 )),
69 _ => Ok(()),
70 }
71}
72
73fn validate_sig(sig: &syn::Signature) -> syn::Result<()> {
74 validate_output(&sig.output)?;
75 validate_arguments(&sig.inputs)?;
76 if sig.asyncness.is_some() {
77 Err(Error::new(
78 Span::call_site(),
79 "resolve_constructor cannot be set on an async function",
80 ))
81 } else if sig.constness.is_some() {
82 Err(Error::new(
83 Span::call_site(),
84 "resolve_constructor cannot be set on a const function",
85 ))
86 } else {
87 Ok(())
88 }
89}
90
91fn validate_arguments(
92 inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
93) -> syn::Result<()> {
94 inputs.iter().all(|arg| match arg {
95 syn::FnArg::Typed(typ) if !is_self_type(&typ.ty) => true,
96 _ => false
97 }).then_some(()).ok_or(Error::new(Span::call_site(), "resolve_constructor cannot be set on a function that have an argument of type Self (either self or Self)"))
98}
99
100fn validate_output(output: &syn::ReturnType) -> syn::Result<()> {
101 match output {
102 syn::ReturnType::Type(_, typ) if is_self_type(typ) => Ok(()),
103 _ => Err(Error::new(
104 Span::call_site(),
105 "resolve_constructor cannot be set on a function that does not return Self",
106 )),
107 }
108}
109
110fn is_self_type(typ: &Box<syn::Type>) -> bool {
111 match &**typ {
112 syn::Type::Path(path) => match path.path.segments.len() {
113 1 => match path
114 .path
115 .segments
116 .last()
117 .unwrap()
118 .ident
119 .to_string()
120 .as_str()
121 {
122 "Self" => true,
123 _ => false,
124 },
125 _ => false,
126 },
127 _ => false,
128 }
129}