1extern crate proc_macro;
6use crate::{
7 attr::Inject,
8 ctxt::Ctxt,
9 symbols::{ARC, INJECT},
10};
11use proc_macro::TokenStream;
12use quote::{format_ident, quote};
13use syn::{
14 parse_macro_input, parse_quote, Error, FnArg, GenericArgument, Ident, ItemFn, Pat,
15 PathArguments, Result, Type, TypePath,
16};
17
18mod attr;
19mod ctxt;
20mod symbols;
21
22fn get_arc_ty(ty: &Type, type_path: &TypePath) -> Result<Type> {
23 let make_arc_error = || Err(Error::new_spanned(ty, "only Arc<...> can be injected"));
24 if type_path.path.leading_colon.is_some() || type_path.path.segments.len() != 1 {
25 return make_arc_error();
26 }
27 let segment = &type_path.path.segments[0];
28 if segment.ident != ARC {
29 return make_arc_error();
30 }
31 let angle_args = match &segment.arguments {
32 PathArguments::AngleBracketed(angle_args) => angle_args,
33 _ => return make_arc_error(),
34 };
35 let args = &angle_args.args;
36 if args.len() != 1 {
37 return make_arc_error();
38 }
39
40 if let GenericArgument::Type(ty) = &args[0] {
41 Ok(ty.clone())
42 } else {
43 make_arc_error()
44 }
45}
46
47#[proc_macro_attribute]
71pub fn inject(attr: TokenStream, input: TokenStream) -> TokenStream {
72 let attr = parse_macro_input!(attr as Inject);
73 let cr = attr.crate_path;
74
75 let mut input = parse_macro_input!(input as ItemFn);
76 let fn_ident = input.sig.ident.clone();
77 let sig = &mut input.sig;
78 let mut defs = vec![];
79 let mut stmts = vec![];
80 let mut ctxt = Ctxt::new();
81 for arg in &mut sig.inputs {
82 if let FnArg::Typed(arg) = arg {
83 if arg.attrs.iter().any(|attr| attr.path() == INJECT) {
84 arg.attrs.retain(|attr| attr.path() != INJECT);
85 let key: Ident = if let Pat::Ident(pat_ident) = &*arg.pat {
86 let ident = &pat_ident.ident;
87 parse_quote! { #ident }
88 } else {
89 ctxt.push_spanned(&*arg.pat, "patterns cannot be injected");
90 continue;
91 };
92
93 let arc_ty = &*arg.ty;
94 let ty = if let Type::Path(type_path) = &*arg.ty {
95 match get_arc_ty(&arg.ty, type_path) {
96 Ok(ty) => ty,
97 Err(e) => {
98 ctxt.push_spanned(&*arg.ty, e);
99 continue;
100 }
101 }
102 } else {
103 ctxt.push_spanned(&*arg.ty, "only Arc<...> can be injected");
104 continue;
105 };
106
107 let ident = format_ident!("__{}_{}_Key", fn_ident, key);
108 let key_str = format!("{}", key);
109 defs.push(quote! {
110 #[allow(non_camel_case_types)]
111 struct #ident;
112 impl #cr::ContainerKey<#ty> for #ident {
113 const KEY: &'static str = #key_str;
114 }
115 });
116
117 stmts.push(parse_quote!( let #cr::Injected(#key, _) = #key; ));
118 *arg.ty = parse_quote!( #cr::Injected<#arc_ty, #ident> );
119 }
120 }
121 }
122
123 input.block.stmts = stmts.into_iter().chain(input.block.stmts).collect();
124
125 if let Err(e) = ctxt.check() {
126 let compile_errors = e.iter().map(Error::to_compile_error);
127 return quote!(#( #compile_errors )*).into();
128 }
129
130 let expanded = quote! {
131 #( #defs )*
132 #input
133 };
134 TokenStream::from(expanded)
135}