anthill_di_derive/
lib.rs

1#![crate_type = "proc-macro"]
2#![recursion_limit = "192"]
3
4extern crate proc_macro;
5extern crate proc_macro2;
6#[macro_use]
7extern crate quote;
8extern crate syn;
9
10use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use syn::Token;
13
/// `quote_spanned!` with every generated token pinned to `Span::call_site()`.
macro_rules! my_quote {
    ($($tokens:tt)*) => {
        quote_spanned! { proc_macro2::Span::call_site() => $($tokens)* }
    };
}
17
18#[proc_macro_derive(constructor, attributes(custom_resolve, resolve, resolve_collection, resolve_by_component, ioc_context))]
19pub fn derive(input: TokenStream) -> TokenStream {
20    let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse item");
21    let result = match ast.data {
22        syn::Data::Enum(ref e) => panic!("doesn't work with enum yet"),
23        syn::Data::Struct(ref s) => new_for_struct(&ast, &s.fields),
24        syn::Data::Union(_) => panic!("doesn't work with unions yet"),
25    };
26    result.into()
27}
28
29fn new_for_struct(
30    ast: &syn::DeriveInput,
31    fields: &syn::Fields,
32) -> proc_macro2::TokenStream {
33    match *fields {
34        syn::Fields::Named(ref fields) => new_impl(&ast, Some(&fields.named)),
35        syn::Fields::Unit => panic!("doesn't work with unit yet"),
36        syn::Fields::Unnamed(_) => panic!("doesn't work with unnamed yet"),
37    }
38}
39
40fn new_impl(
41    ast: &syn::DeriveInput,
42    fields: Option<&syn::punctuated::Punctuated<syn::Field, Token![,]>>,
43) -> proc_macro2::TokenStream {
44    let name = &ast.ident;
45    let empty = Default::default();
46    let fields: Vec<_> = fields
47        .unwrap_or(&empty)
48        .iter()
49        .enumerate()
50        .map(|(i, f)| FieldExt::new(f, i))
51        .collect();
52
53    let assigns = fields.iter().filter(|a| !a.is_ioc_context()).map(|f| f.as_assign());
54    let assigns = my_quote![#(#assigns);*]; // ;
55
56    let inits = fields.iter().map(|f| f.as_init());
57    let inits = my_quote![#(#inits),*];
58
59    let ioc_context_init = fields.iter().filter(|a| a.is_ioc_context()).map(|f| f.as_assign());
60    let ioc_context_init = my_quote![#(#ioc_context_init)*];
61
62    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
63
64    #[cfg(not(feature = "async-mode"))]
65    my_quote! {
66        impl #impl_generics anthill_di::Constructor for #name #ty_generics #where_clause {
67            fn ctor(ctx: anthill_di::DependencyContext) -> anthill_di::types::BuildDependencyResult<Self> {
68                let ctx = ctx;
69                    
70                #assigns;
71                #ioc_context_init;
72
73                Ok(#name {#inits} )
74            }
75        }
76    }
77
78    #[cfg(feature = "async-mode")]
79    my_quote! {
80        impl #impl_generics anthill_di::Constructor for #name #ty_generics #where_clause {
81            fn ctor<'async_trait>(ctx: anthill_di::DependencyContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = anthill_di::types::BuildDependencyResult<Self>> + core::marker::Send + core::marker::Sync + 'async_trait>> where Self: 'async_trait {
82                Box::pin(async move {
83                    let ctx = ctx;
84                    
85                    #assigns;
86                    #ioc_context_init;
87
88                    Ok(#name {#inits} )
89                })
90            }
91        }
92    }
93}
94
95struct FieldExt<'a> {
96    ty: &'a syn::Type,
97    attr: Option<FieldAttr>,
98    ident: syn::Ident,
99}
100
101impl<'a> std::fmt::Debug for FieldExt<'a> {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("FieldExt").field("attr", &self.attr).field("ident", &self.ident).finish()
104    }
105}
106
107impl<'a> FieldExt<'a> {
108    pub fn new(field: &'a syn::Field, idx: usize) -> FieldExt<'a> {
109        FieldExt {
110            ty: &field.ty,
111            attr: FieldAttr::parse(&field.attrs),
112            ident: field.ident.clone().unwrap(),
113        }
114    }
115
116    pub fn is_phantom_data(&self) -> bool {
117        match *self.ty {
118            syn::Type::Path(syn::TypePath {
119                qself: None,
120                ref path,
121            }) => path
122                .segments
123                .last()
124                .map(|x| x.ident == "PhantomData")
125                .unwrap_or(false),
126            _ => false,
127        }
128    }
129
130    pub fn is_ioc_context(&self) -> bool {
131        if let Some (attr) = &self.attr {
132            if let FieldAttr::IocContext = attr {
133                return true;
134            }
135        }
136
137        false
138    }
139
140    pub fn as_assign(&self) -> proc_macro2::TokenStream {
141        let f_name = &self.ident;
142
143        let init = if self.is_phantom_data() {
144            my_quote!(::std::marker::PhantomData)
145        } else {
146            match self.attr {
147                #[cfg(not(feature = "async-mode"))]
148                None => my_quote!(ctx.resolve()?),
149                #[cfg(feature = "async-mode")]
150                None => my_quote!(ctx.resolve().await?),
151                Some(ref attr) => attr.as_tokens(),
152            }
153        };
154
155        my_quote!(let #f_name = #init)
156    }
157
158    pub fn as_init(&self) -> proc_macro2::TokenStream {
159        let f_name = &self.ident;
160        my_quote!(#f_name)
161    }
162}
163
164#[derive(Debug)]
165enum FieldAttr {
166    IocContext,
167    Resolve,
168    ResolveCollection,
169    ResolveByComponent(proc_macro2::Ident),
170    Value(proc_macro2::TokenStream),
171}
172
173impl FieldAttr {
174    pub fn as_tokens(&self) -> proc_macro2::TokenStream {
175        match *self {
176            FieldAttr::IocContext => my_quote!(ctx),
177            #[cfg(not(feature = "async-mode"))]
178            FieldAttr::Resolve => my_quote!(ctx.resolve()?),
179            #[cfg(feature = "async-mode")]
180            FieldAttr::Resolve => my_quote!(ctx.resolve().await?),
181            #[cfg(not(feature = "async-mode"))]
182            FieldAttr::ResolveCollection => my_quote!(ctx.resolve_collection()?),
183            #[cfg(feature = "async-mode")]
184            FieldAttr::ResolveCollection => my_quote!(ctx.resolve_collection().await?),
185            #[cfg(not(feature = "async-mode"))]
186            FieldAttr::ResolveByComponent(ref s) => my_quote!(ctx.resolve_by_type_id(std::any::TypeId::of::<#s>())?),
187            #[cfg(feature = "async-mode")]
188            FieldAttr::ResolveByComponent(ref s) => my_quote!(ctx.resolve_by_type_id(std::any::TypeId::of::<#s>()).await?),
189            FieldAttr::Value(ref s) => my_quote!(#s),
190        }
191    }
192
193    pub fn parse(attrs: &[syn::Attribute]) -> Option<FieldAttr> {
194        use syn::{AttrStyle, Meta, NestedMeta};
195
196        //let mut result = None;
197        for attr in attrs.iter() {
198            match attr.style {
199                AttrStyle::Outer => {}
200                _ => continue,
201            }
202            let last_attr_path = attr
203                .path
204                .segments
205                .iter()
206                .last()
207                .expect("Expected at least one segment where #[segment[::segment*](..)]");
208
209            if (*last_attr_path).ident != "ioc_context" &&
210                (*last_attr_path).ident != "resolve" &&
211                (*last_attr_path).ident != "resolve_collection" &&
212                (*last_attr_path).ident != "resolve_by_component" &&
213                (*last_attr_path).ident != "custom_resolve" {
214                continue;
215            }
216            let meta = match attr.parse_meta() {
217                Ok(meta) => meta,
218                Err(_) => continue,
219            };
220
221            if meta.path().is_ident("ioc_context") {
222                match (meta) {
223                    Meta::Path(_) => {
224                        return Some(FieldAttr::IocContext)    
225                    }
226                    _ => panic!("Invalid #[ioc_context] attribute: #[ioc_context{}]", path_to_string(&meta.path())),
227                }
228            }
229 
230            if meta.path().is_ident("resolve") {
231                //if ()
232                match (meta) {
233                    Meta::Path(_) => {
234                        return Some(FieldAttr::Resolve)    
235                    }
236                    _ => panic!("Invalid #[resolve] attribute: #[resolve{}]", path_to_string(&meta.path())),
237                }
238            }
239
240            if meta.path().is_ident("resolve_collection") {
241                match (meta) {
242                    Meta::Path(_) => {
243                        return Some(FieldAttr::ResolveCollection)    
244                    }
245                    _ => panic!("Invalid #[resolve_collection] attribute: #[{}]", path_to_string(&meta.path())),
246                }
247            }
248
249            if meta.path().is_ident("resolve_by_component") {
250                match (meta) {
251                    Meta::List(list) => {//
252                        match list.nested.iter().nth(0).expect("resolve_by_component attribute required 1 element") {
253                            NestedMeta::Meta(Meta::Path(ref path)) => {
254                                let ident = path.segments.iter().nth(0).unwrap().clone().ident;
255                                //let r = path.get_ident();
256                                return Some(FieldAttr::ResolveByComponent(ident));
257                            },
258                            _ => panic!("Invalid #[resolve_by_component] attribute")
259                        }
260                        
261                    }
262                    _ => panic!("Invalid #[resolve_by_component] attribute: #[resolve_by_component({})]", path_to_string(&meta.path())),
263                }
264            }
265
266            if meta.path().is_ident("custom_resolve") {
267                match (meta) {
268                    Meta::List(list) => {
269                        match list.nested.iter().nth(0).expect("custom_resolve attribute required 1 element") {
270                            NestedMeta::Meta(Meta::NameValue(ref kv)) => {
271                                if let syn::Lit::Str(ref s) = kv.lit {
272                                    if kv.path.is_ident("value") {
273                                        let tokens = lit_str_to_token_stream(s).ok().expect(&format!(
274                                            "Invalid expression in #[custom_resolve]: `{}`",
275                                            s.value()
276                                        ));
277                                        return Some(FieldAttr::Value(tokens));
278                                    } else {
279                                        panic!("Invalid #[custom_resolve] attribute: #[custom_resolve({} = ..)]", path_to_string(&kv.path));
280                                    }
281                                } else {
282                                    panic!("Non-string literal value in #[custom_resolve] attribute");
283                                }
284                            },
285                            _ => panic!("Non-string literal value in #[custom_resolve] attribute"),
286                        }
287                    }
288                    _ => panic!("Invalid #[custom_resolve] attribute: #[custom_resolve({})]", path_to_string(&meta.path())),
289                }
290            }
291        }
292        return None;
293    }
294}
295
296fn path_to_string(path: &syn::Path) -> String {
297    path.segments.iter().map(|s| s.ident.to_string()).collect::<Vec<String>>().join("::")
298}
299
300fn lit_str_to_token_stream(s: &syn::LitStr) -> Result<TokenStream2, proc_macro2::LexError> {
301    let code = s.value();
302    let ts: TokenStream2 = code.parse()?;
303    Ok(set_ts_span_recursive(ts, &s.span()))
304}
305
306fn set_ts_span_recursive(ts: TokenStream2, span: &proc_macro2::Span) -> TokenStream2 {
307    ts.into_iter().map(|mut tt| {
308        tt.set_span(span.clone());
309        if let proc_macro2::TokenTree::Group(group) = &mut tt {
310            let stream = set_ts_span_recursive(group.stream(), span);
311            *group = proc_macro2::Group::new(group.delimiter(), stream);
312        }
313        tt
314    }).collect()
315}