dill_impl/
lib.rs

1extern crate proc_macro;
2
3mod types;
4
5use proc_macro::TokenStream;
6use quote::{format_ident, quote};
7use types::InjectionType;
8
9/////////////////////////////////////////////////////////////////////////////////////////
10
11#[proc_macro_attribute]
12pub fn component(attr: TokenStream, item: TokenStream) -> TokenStream {
13    let ast: syn::Item = syn::parse(item).unwrap();
14    let vis: syn::Visibility = syn::parse(attr).unwrap();
15    match ast {
16        syn::Item::Struct(struct_ast) => component_from_struct(struct_ast),
17        syn::Item::Impl(impl_ast) => component_from_impl(vis, impl_ast),
18        _ => {
19            panic!("The #[component] macro can only be used on struct definition or an impl block")
20        }
21    }
22}
23
24/////////////////////////////////////////////////////////////////////////////////////////
25
26#[proc_macro_attribute]
27pub fn scope(_args: TokenStream, item: TokenStream) -> TokenStream {
28    item
29}
30
31/////////////////////////////////////////////////////////////////////////////////////////
32
33#[proc_macro_attribute]
34pub fn interface(_args: TokenStream, item: TokenStream) -> TokenStream {
35    item
36}
37
38/////////////////////////////////////////////////////////////////////////////////////////
39
40#[proc_macro_attribute]
41pub fn meta(_args: TokenStream, item: TokenStream) -> TokenStream {
42    item
43}
44
45/////////////////////////////////////////////////////////////////////////////////////////
46
47fn component_from_struct(mut ast: syn::ItemStruct) -> TokenStream {
48    let impl_name = &ast.ident;
49    let impl_type = syn::parse2(quote! { #impl_name }).unwrap();
50    let impl_generics = syn::parse2(quote! {}).unwrap();
51
52    let args: Vec<_> = ast
53        .fields
54        .iter_mut()
55        .map(|f| {
56            (
57                f.ident.clone().unwrap(),
58                f.ty.clone(),
59                extract_attr_explicit(&mut f.attrs),
60            )
61        })
62        .collect();
63
64    let scope_type =
65        get_scope(&ast.attrs).unwrap_or_else(|| syn::parse_str("::dill::Transient").unwrap());
66
67    let interfaces = get_interfaces(&ast.attrs);
68    let meta = get_meta(&ast.attrs);
69
70    let mut gen: TokenStream = quote! { #ast }.into();
71    let builder: TokenStream = implement_builder(
72        &ast.vis,
73        &impl_type,
74        &impl_generics,
75        scope_type,
76        interfaces,
77        meta,
78        args,
79        false,
80    );
81
82    gen.extend(builder);
83    gen
84}
85
86/////////////////////////////////////////////////////////////////////////////////////////
87
88fn component_from_impl(vis: syn::Visibility, mut ast: syn::ItemImpl) -> TokenStream {
89    let impl_generics = &ast.generics;
90    let impl_type = &ast.self_ty;
91    let new = get_new(&mut ast.items).expect(
92        "When using #[component] macro on the impl block it's expected to contain a new() \
93         function. Otherwise use #[derive(Builder)] on the struct.",
94    );
95
96    let args: Vec<_> = new
97        .sig
98        .inputs
99        .iter_mut()
100        .map(|arg| match arg {
101            syn::FnArg::Typed(targ) => targ,
102            _ => panic!("Unexpected argument in new() function"),
103        })
104        .map(|arg| {
105            (
106                match arg.pat.as_ref() {
107                    syn::Pat::Ident(ident) => ident.ident.clone(),
108                    _ => panic!("Unexpected format of arguments in new() function"),
109                },
110                arg.ty.as_ref().clone(),
111                extract_attr_explicit(&mut arg.attrs),
112            )
113        })
114        .collect();
115
116    let scope_type =
117        get_scope(&ast.attrs).unwrap_or_else(|| syn::parse_str("::dill::Transient").unwrap());
118
119    let interfaces = get_interfaces(&ast.attrs);
120    let meta = get_meta(&ast.attrs);
121
122    let mut gen: TokenStream = quote! { #ast }.into();
123    let builder: TokenStream = implement_builder(
124        &vis,
125        impl_type,
126        impl_generics,
127        scope_type,
128        interfaces,
129        meta,
130        args,
131        true,
132    );
133
134    gen.extend(builder);
135    gen
136}
137
138/////////////////////////////////////////////////////////////////////////////////////////
139
140#[allow(clippy::too_many_arguments)]
141fn implement_builder(
142    impl_vis: &syn::Visibility,
143    impl_type: &syn::Type,
144    _impl_generics: &syn::Generics,
145    scope_type: syn::Path,
146    interfaces: Vec<syn::Type>,
147    meta: Vec<syn::ExprStruct>,
148    args: Vec<(syn::Ident, syn::Type, bool)>,
149    has_new: bool,
150) -> TokenStream {
151    let builder_name = format_ident!("{}Builder", quote! { #impl_type }.to_string());
152
153    let arg_name: Vec<_> = args.iter().map(|(name, _, _)| name).collect();
154
155    let meta_provide: Vec<_> = meta
156        .iter()
157        .enumerate()
158        .map(|(i, e)| implement_meta_provide(i, e))
159        .collect();
160    let meta_vars: Vec<_> = meta
161        .iter()
162        .enumerate()
163        .map(|(i, e)| implement_meta_var(i, e))
164        .collect();
165
166    let mut arg_override_fn_field = Vec::new();
167    let mut arg_override_fn_field_ctor = Vec::new();
168    let mut arg_override_setters = Vec::new();
169    let mut arg_prepare_dependency = Vec::new();
170    let mut arg_provide_dependency = Vec::new();
171    let mut arg_check_dependency = Vec::new();
172
173    for (name, typ, is_explicit) in &args {
174        let (
175            override_fn_field,
176            override_fn_field_ctor,
177            override_setters,
178            prepare_dependency,
179            provide_dependency,
180            check_dependency,
181        ) = implement_arg(name, typ, &builder_name, *is_explicit);
182
183        arg_override_fn_field.push(override_fn_field);
184        arg_override_fn_field_ctor.push(override_fn_field_ctor);
185        arg_override_setters.push(override_setters);
186        arg_prepare_dependency.push(prepare_dependency);
187        arg_provide_dependency.push(provide_dependency);
188        arg_check_dependency.push(check_dependency);
189    }
190
191    let explicit_arg_decl: Vec<_> = args
192        .iter()
193        .filter(|(_, _, is_explicit)| *is_explicit)
194        .map(|(ident, ty, _)| quote! { #ident: #ty })
195        .collect();
196    let explicit_arg_provide: Vec<_> = args
197        .iter()
198        .filter(|(_, _, is_explicit)| *is_explicit)
199        .map(|(ident, _, _)| quote! { #ident })
200        .collect();
201
202    let ctor = if !has_new {
203        quote! {
204            #impl_type {
205                #( #arg_name: #arg_provide_dependency, )*
206            }
207        }
208    } else {
209        quote! {
210            #impl_type::new(#( #arg_provide_dependency, )*)
211        }
212    };
213
214    let component_or_explicit_factory = if explicit_arg_decl.is_empty() {
215        quote! {
216            impl ::dill::Component for #impl_type {
217                type Builder = #builder_name;
218
219                fn register(cat: &mut ::dill::CatalogBuilder) {
220                    cat.add_builder(Self::builder());
221
222                    #(
223                        cat.bind::<#interfaces, #impl_type>();
224                    )*
225                }
226
227                fn builder() -> Self::Builder {
228                    #builder_name::new()
229                }
230            }
231        }
232    } else {
233        quote! {
234            impl #impl_type {
235                pub fn builder(
236                    #(#explicit_arg_decl),*
237                ) -> #builder_name {
238                    #builder_name::new(
239                        #(#explicit_arg_provide),*
240                    )
241                }
242            }
243        }
244    };
245
246    let builder = quote! {
247        #impl_vis struct #builder_name {
248            dill_builder_scope: #scope_type,
249            #(#arg_override_fn_field),*
250        }
251
252        impl #builder_name {
253            #( #meta_vars )*
254
255            pub fn new(
256                #(#explicit_arg_decl),*
257            ) -> Self {
258                Self {
259                    dill_builder_scope: #scope_type::new(),
260                    #(#arg_override_fn_field_ctor),*
261                }
262            }
263
264            #( #arg_override_setters )*
265
266            fn build(&self, cat: &::dill::Catalog) -> Result<#impl_type, ::dill::InjectionError> {
267                use ::dill::DependencySpec;
268                #( #arg_prepare_dependency )*
269                Ok(#ctor)
270            }
271        }
272
273        impl ::dill::Builder for #builder_name {
274            fn instance_type_id(&self) -> ::std::any::TypeId {
275                ::std::any::TypeId::of::<#impl_type>()
276            }
277
278            fn instance_type_name(&self) -> &'static str {
279                ::std::any::type_name::<#impl_type>()
280            }
281
282            fn interfaces(&self, clb: &mut dyn FnMut(&::dill::InterfaceDesc) -> bool) {
283                #(
284                    if !clb(&::dill::InterfaceDesc {
285                        type_id: ::std::any::TypeId::of::<#interfaces>(),
286                        type_name: ::std::any::type_name::<#interfaces>(),
287                    }) { return }
288                )*
289            }
290
291            fn metadata<'a>(&'a self, clb: & mut dyn FnMut(&'a dyn std::any::Any) -> bool) {
292                #( #meta_provide )*
293            }
294
295            fn get_any(&self, cat: &::dill::Catalog) -> Result<::std::sync::Arc<dyn ::std::any::Any + Send + Sync>, ::dill::InjectionError> {
296                Ok(::dill::TypedBuilder::get(self, cat)?)
297            }
298
299            fn check(&self, cat: &::dill::Catalog) -> Result<(), ::dill::ValidationError> {
300                use ::dill::DependencySpec;
301
302                let mut errors = Vec::new();
303                #(
304                if let Err(err) = #arg_check_dependency {
305                    errors.push(err);
306                }
307                )*
308                if errors.len() != 0 {
309                    Err(::dill::ValidationError { errors })
310                } else {
311                    Ok(())
312                }
313            }
314        }
315
316        impl ::dill::TypedBuilder<#impl_type> for #builder_name {
317            fn get(&self, cat: &::dill::Catalog) -> Result<std::sync::Arc<#impl_type>, ::dill::InjectionError> {
318                use ::dill::Scope;
319
320                if let Some(inst) = self.dill_builder_scope.get() {
321                    return Ok(inst.downcast().unwrap());
322                }
323
324                let inst = ::std::sync::Arc::new(self.build(cat)?);
325
326                self.dill_builder_scope.set(inst.clone());
327                Ok(inst)
328            }
329        }
330
331        #(
332            // Allows casting TypedBuider<T> into TypedBuilder<dyn I> for all declared interfaces
333            impl ::dill::TypedBuilderCast<#interfaces> for #builder_name
334            {
335                fn cast(self) -> impl ::dill::TypedBuilder<#interfaces> {
336                    struct B(#builder_name);
337
338                    impl ::dill::Builder for B {
339                        fn instance_type_id(&self) -> ::std::any::TypeId {
340                            self.0.instance_type_id()
341                        }
342                        fn instance_type_name(&self) -> &'static str {
343                            self.0.instance_type_name()
344                        }
345                        fn interfaces(&self, clb: &mut dyn FnMut(&::dill::InterfaceDesc) -> bool) {
346                            self.0.interfaces(clb)
347                        }
348                        fn metadata<'a>(&'a self, clb: &mut dyn FnMut(&'a dyn std::any::Any) -> bool) {
349                            self.0.metadata(clb)
350                        }
351                        fn get_any(&self, cat: &::dill::Catalog) -> Result<std::sync::Arc<dyn std::any::Any + Send + Sync>, ::dill::InjectionError> {
352                            self.0.get_any(cat)
353                        }
354                        fn check(&self, cat: &::dill::Catalog) -> Result<(), ::dill::ValidationError> {
355                            self.0.check(cat)
356                        }
357                    }
358
359                    impl ::dill::TypedBuilder<#interfaces> for B {
360                        fn get(&self, cat: &::dill::Catalog) -> Result<::std::sync::Arc<#interfaces>, ::dill::InjectionError> {
361                            match self.0.get(cat) {
362                                Ok(v) => Ok(v),
363                                Err(e) => Err(e),
364                            }
365                        }
366                    }
367
368                    B(self)
369                }
370            }
371        )*
372    };
373
374    quote! {
375        #component_or_explicit_factory
376
377        #builder
378    }
379    .into()
380}
381
382/////////////////////////////////////////////////////////////////////////////////////////
383
384fn implement_arg(
385    name: &syn::Ident,
386    typ: &syn::Type,
387    builder: &syn::Ident,
388    is_explicit: bool,
389) -> (
390    proc_macro2::TokenStream, // override_fn_field
391    proc_macro2::TokenStream, // override_fn_field_ctor
392    proc_macro2::TokenStream, // override_setters
393    proc_macro2::TokenStream, // prepare_dependency
394    proc_macro2::TokenStream, // provide_dependency
395    proc_macro2::TokenStream, // check_dependency
396) {
397    let override_fn_name = format_ident!("arg_{}_fn", name);
398
399    let injection_type = if is_explicit {
400        InjectionType::Value { typ: typ.clone() }
401    } else {
402        types::deduce_injection_type(typ)
403    };
404
405    // Used to declare the field that stores the override factory function or
406    // an explicit argument
407    let override_fn_field = if is_explicit {
408        quote! { #name: #typ }
409    } else {
410        match &injection_type {
411            InjectionType::Reference { .. } => proc_macro2::TokenStream::new(),
412            _ => quote! {
413                #override_fn_name: Option<Box<dyn Fn(&::dill::Catalog) -> Result<#typ, ::dill::InjectionError> + Send + Sync>>
414            },
415        }
416    };
417
418    // Used initialize the field that stores the override factory function or
419    // an explicit argument
420    let override_fn_field_ctor = if is_explicit {
421        quote! { #name: #name }
422    } else {
423        match &injection_type {
424            InjectionType::Reference { .. } => proc_macro2::TokenStream::new(),
425            _ => quote! { #override_fn_name: None },
426        }
427    };
428
429    // Used to create with_* and with_*_fn setters for dependency overrides
430    let override_setters = if is_explicit {
431        proc_macro2::TokenStream::new()
432    } else {
433        match &injection_type {
434            InjectionType::Reference { .. } => proc_macro2::TokenStream::new(),
435            _ => {
436                let setter_val_name = format_ident!("with_{}", name);
437                let setter_fn_name = format_ident!("with_{}_fn", name);
438                quote! {
439                    pub fn #setter_val_name(mut self, val: #typ) -> #builder {
440                        self.#override_fn_name = Some(Box::new(move |_| Ok(val.clone())));
441                        self
442                    }
443
444                    pub fn #setter_fn_name(
445                        mut self,
446                        fun: impl Fn(&::dill::Catalog) -> Result<#typ, ::dill::InjectionError> + 'static + Send + Sync
447                    ) -> #builder {
448                        self.#override_fn_name = Some(Box::new(fun));
449                        self
450                    }
451                }
452            }
453        }
454    };
455
456    // Used in TBuilder::check() to validate the dependency
457    let check_dependency = if is_explicit {
458        quote! { Ok(()) }
459    } else {
460        let do_check_dependency = get_do_check_dependency(&injection_type);
461        match &injection_type {
462            InjectionType::Reference { .. } => quote! { #do_check_dependency },
463            _ => quote! {
464                match &self.#override_fn_name {
465                    Some(_) => Ok(()),
466                    _ => #do_check_dependency,
467                }
468            },
469        }
470    };
471
472    // Used in TBuilder::build() to extract the dependency from the catalog
473    let prepare_dependency = if is_explicit {
474        proc_macro2::TokenStream::new()
475    } else {
476        let do_get_dependency = get_do_get_dependency(&injection_type);
477        match &injection_type {
478            InjectionType::Reference { .. } => quote! { let #name = #do_get_dependency; },
479            _ => quote! {
480                let #name = match &self.#override_fn_name {
481                    Some(fun) => fun(cat)?,
482                    _ => #do_get_dependency,
483                };
484            },
485        }
486    };
487
488    // Called to provide dependency value to T's constructor
489    let provide_dependency = if is_explicit {
490        quote! { self.#name.clone() }
491    } else {
492        match &injection_type {
493            InjectionType::Reference { .. } => quote! { #name.as_ref() },
494            _ => quote! { #name },
495        }
496    };
497
498    (
499        override_fn_field,
500        override_fn_field_ctor,
501        override_setters,
502        prepare_dependency,
503        provide_dependency,
504        check_dependency,
505    )
506}
507
508/////////////////////////////////////////////////////////////////////////////////////////
509
510fn get_do_check_dependency(injection_type: &InjectionType) -> proc_macro2::TokenStream {
511    match injection_type {
512        InjectionType::Arc { inner } => quote! { ::dill::OneOf::<#inner>::check(cat) },
513        InjectionType::Reference { inner } => quote! { ::dill::OneOf::<#inner>::check(cat) },
514        InjectionType::Option { element } => match element.as_ref() {
515            InjectionType::Arc { inner } => {
516                quote! { ::dill::Maybe::<::dill::OneOf::<#inner>>::check(cat) }
517            }
518            InjectionType::Value { typ } => {
519                quote! { ::dill::Maybe::<::dill::OneOf::<#typ>>::check(cat) }
520            }
521            _ => {
522                unimplemented!("Currently only Option<Arc<Iface>> and Option<Value> are supported")
523            }
524        },
525        InjectionType::Lazy { element } => match element.as_ref() {
526            InjectionType::Arc { inner } => {
527                quote! { ::dill::specs::Lazy::<::dill::OneOf::<#inner>>::check(cat) }
528            }
529            _ => unimplemented!("Currently only Lazy<Arc<Iface>> is supported"),
530        },
531        InjectionType::Vec { item } => match item.as_ref() {
532            InjectionType::Arc { inner } => quote! { ::dill::AllOf::<#inner>::check(cat) },
533            _ => unimplemented!("Currently only Vec<Arc<Iface>> is supported"),
534        },
535        InjectionType::Value { typ } => quote! { ::dill::OneOf::<#typ>::check(cat) },
536    }
537}
538
539fn get_do_get_dependency(injection_type: &InjectionType) -> proc_macro2::TokenStream {
540    match injection_type {
541        InjectionType::Arc { inner } => quote! { ::dill::OneOf::<#inner>::get(cat)? },
542        InjectionType::Reference { inner } => quote! { ::dill::OneOf::<#inner>::get(cat)? },
543        InjectionType::Option { element } => match element.as_ref() {
544            InjectionType::Arc { inner } => {
545                quote! { ::dill::Maybe::<::dill::OneOf::<#inner>>::get(cat)? }
546            }
547            InjectionType::Value { typ } => {
548                quote! { ::dill::Maybe::<::dill::OneOf::<#typ>>::get(cat)?.map(|v| v.as_ref().clone()) }
549            }
550            _ => {
551                unimplemented!("Currently only Option<Arc<Iface>> and Option<Value> are supported")
552            }
553        },
554        InjectionType::Lazy { element } => match element.as_ref() {
555            InjectionType::Arc { inner } => {
556                quote! { ::dill::specs::Lazy::<::dill::OneOf::<#inner>>::get(cat)? }
557            }
558            _ => unimplemented!("Currently only Lazy<Arc<Iface>> is supported"),
559        },
560        InjectionType::Vec { item } => match item.as_ref() {
561            InjectionType::Arc { inner } => quote! { ::dill::AllOf::<#inner>::get(cat)? },
562            _ => unimplemented!("Currently only Vec<Arc<Iface>> is supported"),
563        },
564        InjectionType::Value { typ } => {
565            quote! { ::dill::OneOf::<#typ>::get(cat).map(|v| v.as_ref().clone())? }
566        }
567    }
568}
569
570/////////////////////////////////////////////////////////////////////////////////////////
571
572fn implement_meta_var(index: usize, expr: &syn::ExprStruct) -> proc_macro2::TokenStream {
573    let ident = format_ident!("_meta_{index}");
574    let typ = &expr.path;
575    quote! {
576        const #ident: #typ = #expr;
577    }
578}
579
580fn implement_meta_provide(index: usize, _expr: &syn::ExprStruct) -> proc_macro2::TokenStream {
581    let ident = format_ident!("_meta_{index}");
582    quote! {
583        if !clb(&Self::#ident) { return }
584    }
585}
586
587/////////////////////////////////////////////////////////////////////////////////////////
588
589/// Searches for `#[scope(X)]` attribute and returns `X`
590fn get_scope(attrs: &Vec<syn::Attribute>) -> Option<syn::Path> {
591    let mut scope = None;
592
593    for attr in attrs {
594        if is_dill_attr(attr, "scope") {
595            attr.parse_nested_meta(|meta| {
596                scope = Some(meta.path);
597                Ok(())
598            })
599            .unwrap();
600        }
601    }
602
603    scope
604}
605
606/////////////////////////////////////////////////////////////////////////////////////////
607
608/// Searches for all `#[interface(X)]` attributes and returns all types
609fn get_interfaces(attrs: &Vec<syn::Attribute>) -> Vec<syn::Type> {
610    let mut interfaces = Vec::new();
611
612    for attr in attrs {
613        if is_dill_attr(attr, "interface") {
614            let iface = attr.parse_args().unwrap();
615            interfaces.push(iface);
616        }
617    }
618
619    interfaces
620}
621
622/////////////////////////////////////////////////////////////////////////////////////////
623
624/// Searches for all `#[meta(X)]` attributes and returns all expressions
625fn get_meta(attrs: &Vec<syn::Attribute>) -> Vec<syn::ExprStruct> {
626    let mut meta = Vec::new();
627
628    for attr in attrs {
629        if is_dill_attr(attr, "meta") {
630            let expr = attr.parse_args().unwrap();
631            meta.push(expr);
632        }
633    }
634
635    meta
636}
637
638/////////////////////////////////////////////////////////////////////////////////////////
639
640fn is_dill_attr<I: ?Sized>(attr: &syn::Attribute, ident: &I) -> bool
641where
642    syn::Ident: PartialEq<I>,
643{
644    if attr.path().is_ident(ident) {
645        true
646    } else {
647        attr.path().segments.len() == 2
648            && &attr.path().segments[0].ident == "dill"
649            && attr.path().segments[1].ident == *ident
650    }
651}
652
653/////////////////////////////////////////////////////////////////////////////////////////
654
655/// Searches `impl` block for `new()` method
656fn get_new(impl_items: &mut [syn::ImplItem]) -> Option<&mut syn::ImplItemFn> {
657    impl_items
658        .iter_mut()
659        .filter_map(|i| match i {
660            syn::ImplItem::Fn(m) => Some(m),
661            _ => None,
662        })
663        .find(|m| m.sig.ident == "new")
664}
665
666/////////////////////////////////////////////////////////////////////////////////////////
667
668fn extract_attr_explicit(attrs: &mut Vec<syn::Attribute>) -> bool {
669    let mut present = false;
670    attrs.retain_mut(|attr| {
671        if is_attr_explicit(attr) {
672            present = true;
673            false
674        } else {
675            true
676        }
677    });
678    present
679}
680
681fn is_attr_explicit(attr: &syn::Attribute) -> bool {
682    if !is_dill_attr(attr, "component") {
683        return false;
684    }
685    let syn::Meta::List(meta) = &attr.meta else {
686        return false;
687    };
688    meta.tokens.to_string().contains("explicit")
689}