coi_actix_web_derive/
lib.rs

1//! Provides the `inject` proc macro for use by the [`coi-actix-web`] crate.
2//!
3//! [`coi-actix-web`]: https://docs.rs/coi-actix-web
4
5extern crate proc_macro;
6use crate::attr::Inject;
7use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{
10    parse_macro_input, parse_quote, Error, FnArg, GenericArgument, Ident, ItemFn, Pat,
11    PathArguments, Result, Type, TypePath,
12};
13
14mod attr;
15mod symbols;
16
17fn get_arc_ty(ty: &Type, type_path: &TypePath) -> Result<Type> {
18    let make_arc_error = || Err(Error::new_spanned(ty, "only Arc<...> can be injected"));
19    if type_path.path.leading_colon.is_some() || type_path.path.segments.len() != 1 {
20        return make_arc_error();
21    }
22    let segment = &type_path.path.segments[0];
23    if segment.ident != "Arc" {
24        return make_arc_error();
25    }
26    let angle_args = match &segment.arguments {
27        PathArguments::AngleBracketed(angle_args) => angle_args,
28        _ => return make_arc_error(),
29    };
30    let args = &angle_args.args;
31    if args.len() != 1 {
32        return make_arc_error();
33    }
34
35    if let GenericArgument::Type(ty) = &args[0] {
36        Ok(ty.clone())
37    } else {
38        make_arc_error()
39    }
40}
41
42/// The #[inject] proc macro should only be applied to functions that will
43/// be passed to [`actix-web`]'s routing APIs.
44///
45/// [`actix-web`]: https://docs.rs/actix-web
46///
47/// ## Examples
48/// ```rust,no_run
49/// use actix_web::Responder;
50/// use coi::Inject;
51/// use coi_actix_web::inject;
52///
53/// # trait IService : Inject {}
54///
55/// #[inject]
56/// async fn get_all(#[inject] service: Arc<dyn IService>) -> Result<impl Responder, ()> {
57///     //...
58///     Ok("Hello, World")
59/// }
60/// ```
61///
62/// This proc macro changes the input arguments to the fn that it's applied to. All `#[inject]` args
63/// get collected into a single type and are pattern matched out. This is to take advantage of the
64/// [`coi-actix-web`] crate's `FromResponse` impls. By ensuring that all injected types are part of
65/// the same type, we can guarantee that all injected types are resolved from the same scoped
66/// container. The downside of this is that the signature you see is not what is generated, and
67/// this makes manually calling these functions more verbose. Since all of these functions are
68/// expected to be passed to [`actix-web`]'s routing APIs, it's not an issue since those are all
69/// generic.
70///
71/// [`coi-actix-web`]: https://docs.rs/coi-actix-web
72/// [`actix-web`]: https://docs.rs/actix-web
73#[proc_macro_attribute]
74pub fn inject(attr: TokenStream, input: TokenStream) -> TokenStream {
75    let attr = parse_macro_input!(attr as Inject);
76    let caw = attr.crate_path;
77
78    let mut input = parse_macro_input!(input as ItemFn);
79    let fn_ident = input.sig.ident.clone();
80    let sig = &mut input.sig;
81    let inputs = &mut sig.inputs;
82    let mut args = vec![];
83    while !inputs.is_empty() {
84        if let Some(arg) = inputs.pop() {
85            args.push(arg);
86        }
87    }
88    args.reverse();
89    let (inject, not_inject): (Vec<_>, Vec<_>) =
90        args.into_iter().partition(|arg| match arg.value() {
91            FnArg::Typed(arg) => arg.attrs.iter().any(|attr| attr.path.is_ident("inject")),
92            _ => false,
93        });
94
95    for arg in not_inject {
96        let (arg, punct) = arg.into_tuple();
97        inputs.push_value(arg);
98        if let Some(punct) = punct {
99            inputs.push_punct(punct);
100        }
101    }
102
103    let num_args = inject.len();
104    let (key, ty): (Vec<Result<Ident>>, Vec<Result<Type>>) = inject
105        .into_iter()
106        .map(|arg| match arg.value() {
107            FnArg::Typed(arg) => {
108                let pat = match &*arg.pat {
109                    Pat::Ident(pat_ident) => {
110                        let ident = &pat_ident.ident;
111                        Ok(parse_quote! { #ident })
112                    }
113                    _ => Err(Error::new_spanned(&*arg.pat, "patterns cannot be injected")),
114                };
115
116                let ty = if let Type::Path(type_path) = &*arg.ty {
117                    get_arc_ty(&*arg.ty, type_path)
118                } else {
119                    Err(Error::new_spanned(
120                        &*arg.ty,
121                        "only Arc<...> can be injected",
122                    ))
123                };
124                (pat, ty)
125            }
126            _ => unreachable!(),
127        })
128        .unzip();
129    let key = match key.into_iter().collect::<Result<Vec<_>>>() {
130        Ok(key) => key,
131        Err(e) => return e.to_compile_error().into(),
132    };
133    let ty = match ty.into_iter().collect::<Result<Vec<_>>>() {
134        Ok(ty) => ty,
135        Err(e) => return e.to_compile_error().into(),
136    };
137    let (defs, container_key): (Vec<_>, Vec<_>) = key
138        .iter()
139        .zip(ty.iter())
140        .map(|(key, ty)| {
141            let ident = format_ident!("__{}_{}_Key", fn_ident, key);
142            let key_str = format!("{}", key);
143            (
144                quote! {
145                    #[allow(non_camel_case_types)]
146                    struct #ident;
147                    impl #caw::ContainerKey<#ty> for #ident {
148                        const KEY: &'static str = #key_str;
149                    }
150                },
151                ident,
152            )
153        })
154        .unzip();
155
156    let injected_arg = if num_args > 1 {
157        parse_quote! {
158            #caw::Injected((#( #key, )*), _):
159            #caw::Injected<(#( ::std::sync::Arc<#ty>, )*), (#( #container_key, )*)>
160        }
161    } else {
162        parse_quote! {
163            #caw::Injected(#( #key, )* _):
164            #caw::Injected<#( ::std::sync::Arc<#ty>, )* #( #container_key, )*>
165        }
166    };
167    inputs.push(injected_arg);
168
169    let expanded = quote! {
170        #( #defs )*
171        #input
172    };
173    TokenStream::from(expanded)
174}