coi_rocket_derive/
lib.rs

1//! Provides the `inject` proc macro for use by the [`coi-rocket`] crate.
2//!
3//! [`coi-rocket`]: https://docs.rs/coi-rocket
4
5extern crate proc_macro;
6use crate::{
7    attr::Inject,
8    ctxt::Ctxt,
9    symbols::{ARC, INJECT},
10};
11use proc_macro::TokenStream;
12use quote::{format_ident, quote};
13use syn::{
14    parse_macro_input, parse_quote, Error, FnArg, GenericArgument, Ident, ItemFn, Pat,
15    PathArguments, Result, Type, TypePath,
16};
17
18mod attr;
19mod ctxt;
20mod symbols;
21
22fn get_arc_ty(ty: &Type, type_path: &TypePath) -> Result<Type> {
23    let make_arc_error = || Err(Error::new_spanned(ty, "only Arc<...> can be injected"));
24    if type_path.path.leading_colon.is_some() || type_path.path.segments.len() != 1 {
25        return make_arc_error();
26    }
27    let segment = &type_path.path.segments[0];
28    if segment.ident != ARC {
29        return make_arc_error();
30    }
31    let angle_args = match &segment.arguments {
32        PathArguments::AngleBracketed(angle_args) => angle_args,
33        _ => return make_arc_error(),
34    };
35    let args = &angle_args.args;
36    if args.len() != 1 {
37        return make_arc_error();
38    }
39
40    if let GenericArgument::Type(ty) = &args[0] {
41        Ok(ty.clone())
42    } else {
43        make_arc_error()
44    }
45}
46
47/// The #[inject] proc macro should only be applied to functions that will
48/// be passed to [`rocket`]'s routing APIs.
49///
50/// [`rocket`]: https://rocket.rs
51///
52/// ## Examples
53/// ```rust,no_run
54/// #![feature(decl_macro)]
55///
56/// use coi::Inject;
57/// use coi_rocket::inject;
58/// use rocket::get;
59/// use std::sync::Arc;
60///
61/// # trait IService : Inject {}
62///
63/// #[inject]
64/// #[get("/path")]
65/// fn get_all(#[inject] service: Arc<dyn IService>) -> String {
66///     // use service here...
67///     String::from("Hello, World")
68/// }
69/// ```
70#[proc_macro_attribute]
71pub fn inject(attr: TokenStream, input: TokenStream) -> TokenStream {
72    let attr = parse_macro_input!(attr as Inject);
73    let cr = attr.crate_path;
74
75    let mut input = parse_macro_input!(input as ItemFn);
76    let fn_ident = input.sig.ident.clone();
77    let sig = &mut input.sig;
78    let mut defs = vec![];
79    let mut stmts = vec![];
80    let mut ctxt = Ctxt::new();
81    for arg in &mut sig.inputs {
82        if let FnArg::Typed(arg) = arg {
83            if arg.attrs.iter().any(|attr| attr.path() == INJECT) {
84                arg.attrs.retain(|attr| attr.path() != INJECT);
85                let key: Ident = if let Pat::Ident(pat_ident) = &*arg.pat {
86                    let ident = &pat_ident.ident;
87                    parse_quote! { #ident }
88                } else {
89                    ctxt.push_spanned(&*arg.pat, "patterns cannot be injected");
90                    continue;
91                };
92
93                let arc_ty = &*arg.ty;
94                let ty = if let Type::Path(type_path) = &*arg.ty {
95                    match get_arc_ty(&arg.ty, type_path) {
96                        Ok(ty) => ty,
97                        Err(e) => {
98                            ctxt.push_spanned(&*arg.ty, e);
99                            continue;
100                        }
101                    }
102                } else {
103                    ctxt.push_spanned(&*arg.ty, "only Arc<...> can be injected");
104                    continue;
105                };
106
107                let ident = format_ident!("__{}_{}_Key", fn_ident, key);
108                let key_str = format!("{}", key);
109                defs.push(quote! {
110                    #[allow(non_camel_case_types)]
111                    struct #ident;
112                    impl #cr::ContainerKey<#ty> for #ident {
113                        const KEY: &'static str = #key_str;
114                    }
115                });
116
117                stmts.push(parse_quote!( let #cr::Injected(#key, _) = #key; ));
118                *arg.ty = parse_quote!( #cr::Injected<#arc_ty, #ident> );
119            }
120        }
121    }
122
123    input.block.stmts = stmts.into_iter().chain(input.block.stmts).collect();
124
125    if let Err(e) = ctxt.check() {
126        let compile_errors = e.iter().map(Error::to_compile_error);
127        return quote!(#( #compile_errors )*).into();
128    }
129
130    let expanded = quote! {
131        #( #defs )*
132        #input
133    };
134    TokenStream::from(expanded)
135}