dill_impl/
lib.rs

1extern crate proc_macro;
2
3mod types;
4
5use proc_macro::TokenStream;
6use quote::{format_ident, quote};
7use types::InjectionType;
8
9/////////////////////////////////////////////////////////////////////////////////////////
10
11#[proc_macro_attribute]
12pub fn component(attr: TokenStream, item: TokenStream) -> TokenStream {
13    let ast: syn::Item = syn::parse(item).unwrap();
14    let vis: syn::Visibility = syn::parse(attr).unwrap();
15    match ast {
16        syn::Item::Struct(struct_ast) => component_from_struct(struct_ast),
17        syn::Item::Impl(impl_ast) => component_from_impl(vis, impl_ast),
18        _ => {
19            panic!("The #[component] macro can only be used on struct definition or an impl block")
20        }
21    }
22}
23
24/////////////////////////////////////////////////////////////////////////////////////////
25
26#[proc_macro_attribute]
27pub fn scope(_args: TokenStream, item: TokenStream) -> TokenStream {
28    item
29}
30
31/////////////////////////////////////////////////////////////////////////////////////////
32
33#[proc_macro_attribute]
34pub fn interface(_args: TokenStream, item: TokenStream) -> TokenStream {
35    item
36}
37
38/////////////////////////////////////////////////////////////////////////////////////////
39
40#[proc_macro_attribute]
41pub fn meta(_args: TokenStream, item: TokenStream) -> TokenStream {
42    item
43}
44
45/////////////////////////////////////////////////////////////////////////////////////////
46
47fn component_from_struct(mut ast: syn::ItemStruct) -> TokenStream {
48    let impl_name = &ast.ident;
49    let impl_type = syn::parse2(quote! { #impl_name }).unwrap();
50    let impl_generics = syn::parse2(quote! {}).unwrap();
51
52    let explicit_args: Vec<_> = ast
53        .fields
54        .iter_mut()
55        .filter_map(|f| {
56            if extract_attr_explicit(&mut f.attrs) {
57                Some((f.ident.clone().unwrap(), f.ty.clone()))
58            } else {
59                None
60            }
61        })
62        .collect();
63
64    let args: Vec<_> = ast
65        .fields
66        .iter()
67        .map(|f| (f.ident.clone().unwrap(), f.ty.clone()))
68        .collect();
69
70    let scope_type =
71        get_scope(&ast.attrs).unwrap_or_else(|| syn::parse_str("::dill::Transient").unwrap());
72
73    let interfaces = get_interfaces(&ast.attrs);
74    let meta = get_meta(&ast.attrs);
75
76    let mut gen: TokenStream = quote! { #ast }.into();
77    let builder: TokenStream = implement_builder(
78        &ast.vis,
79        &impl_type,
80        &impl_generics,
81        scope_type,
82        interfaces,
83        meta,
84        args,
85        explicit_args,
86        false,
87    );
88
89    gen.extend(builder);
90    gen
91}
92
93/////////////////////////////////////////////////////////////////////////////////////////
94
95fn component_from_impl(vis: syn::Visibility, mut ast: syn::ItemImpl) -> TokenStream {
96    let impl_generics = &ast.generics;
97    let impl_type = &ast.self_ty;
98    let new = get_new(&mut ast.items).expect(
99        "When using #[component] macro on the impl block it's expected to contain a new() \
100         function. Otherwise use #[derive(Builder)] on the struct.",
101    );
102
103    let explicit_args: Vec<_> = new
104        .sig
105        .inputs
106        .iter_mut()
107        .filter_map(|f| match f {
108            syn::FnArg::Receiver(_) => None,
109            syn::FnArg::Typed(f) => {
110                if extract_attr_explicit(&mut f.attrs) {
111                    Some((
112                        match f.pat.as_ref() {
113                            syn::Pat::Ident(ident) => ident.ident.clone(),
114                            _ => panic!("Unexpected format of arguments in new() function"),
115                        },
116                        f.ty.as_ref().clone(),
117                    ))
118                } else {
119                    None
120                }
121            }
122        })
123        .collect();
124
125    let args: Vec<_> = new
126        .sig
127        .inputs
128        .iter()
129        .map(|arg| match arg {
130            syn::FnArg::Typed(targ) => targ,
131            _ => panic!("Unexpected argument in new() function"),
132        })
133        .map(|arg| {
134            (
135                match arg.pat.as_ref() {
136                    syn::Pat::Ident(ident) => ident.ident.clone(),
137                    _ => panic!("Unexpected format of arguments in new() function"),
138                },
139                arg.ty.as_ref().clone(),
140            )
141        })
142        .collect();
143
144    let scope_type =
145        get_scope(&ast.attrs).unwrap_or_else(|| syn::parse_str("::dill::Transient").unwrap());
146
147    let interfaces = get_interfaces(&ast.attrs);
148    let meta = get_meta(&ast.attrs);
149
150    let mut gen: TokenStream = quote! { #ast }.into();
151    let builder: TokenStream = implement_builder(
152        &vis,
153        impl_type,
154        impl_generics,
155        scope_type,
156        interfaces,
157        meta,
158        args,
159        explicit_args,
160        true,
161    );
162
163    gen.extend(builder);
164    gen
165}
166
167/////////////////////////////////////////////////////////////////////////////////////////
168
169#[allow(clippy::too_many_arguments)]
170fn implement_builder(
171    impl_vis: &syn::Visibility,
172    impl_type: &syn::Type,
173    _impl_generics: &syn::Generics,
174    scope_type: syn::Path,
175    interfaces: Vec<syn::Type>,
176    meta: Vec<syn::ExprStruct>,
177    args: Vec<(syn::Ident, syn::Type)>,
178    explicit_args: Vec<(syn::Ident, syn::Type)>,
179    has_new: bool,
180) -> TokenStream {
181    let builder_name = format_ident!("{}Builder", quote! { #impl_type }.to_string());
182
183    let arg_name: Vec<_> = args.iter().map(|(name, _)| name).collect();
184
185    let meta_provide: Vec<_> = meta
186        .iter()
187        .enumerate()
188        .map(|(i, e)| implement_meta_provide(i, e))
189        .collect();
190    let meta_vars: Vec<_> = meta
191        .iter()
192        .enumerate()
193        .map(|(i, e)| implement_meta_var(i, e))
194        .collect();
195
196    let mut arg_override_fn_field = Vec::new();
197    let mut arg_override_fn_field_ctor = Vec::new();
198    let mut arg_override_setters = Vec::new();
199    let mut arg_prepare_dependency = Vec::new();
200    let mut arg_provide_dependency = Vec::new();
201    let mut arg_check_dependency = Vec::new();
202
203    for (name, typ) in &args {
204        let (
205            override_fn_field,
206            override_fn_field_ctor,
207            override_setters,
208            prepare_dependency,
209            provide_dependency,
210            check_dependency,
211        ) = implement_arg(name, typ, &builder_name);
212
213        arg_override_fn_field.push(override_fn_field);
214        arg_override_fn_field_ctor.push(override_fn_field_ctor);
215        arg_override_setters.push(override_setters);
216        arg_prepare_dependency.push(prepare_dependency);
217        arg_provide_dependency.push(provide_dependency);
218        arg_check_dependency.push(check_dependency);
219    }
220
221    let explicit_arg: Vec<_> = explicit_args
222        .iter()
223        .map(|(ident, ty)| quote! { #ident: #ty })
224        .collect();
225    let explicit_arg_set: Vec<_> = explicit_args
226        .iter()
227        .map(|(ident, _)| {
228            let setter_val_name = format_ident!("with_{}", ident);
229            quote! { .#setter_val_name(#ident) }
230        })
231        .collect();
232
233    let ctor = if !has_new {
234        quote! {
235            #impl_type {
236                #( #arg_name: #arg_provide_dependency, )*
237            }
238        }
239    } else {
240        quote! {
241            #impl_type::new(#( #arg_provide_dependency, )*)
242        }
243    };
244
245    let component_or_explicit_factory = if explicit_args.is_empty() {
246        quote! {
247            impl ::dill::Component for #impl_type {
248                type Builder = #builder_name;
249
250                fn register(cat: &mut ::dill::CatalogBuilder) {
251                    cat.add_builder(Self::builder());
252
253                    #(
254                        cat.bind::<#interfaces, #impl_type>();
255                    )*
256                }
257
258                fn builder() -> Self::Builder {
259                    #builder_name::new()
260                }
261            }
262        }
263    } else {
264        quote! {
265            impl #impl_type {
266                pub fn builder(
267                    #(#explicit_arg),*
268                ) -> #builder_name {
269                    #builder_name::new()
270                    #(
271                        #explicit_arg_set
272                    )*
273                }
274            }
275        }
276    };
277
278    let builder = quote! {
279        #impl_vis struct #builder_name {
280            scope: #scope_type,
281            #(
282                #arg_override_fn_field
283            )*
284        }
285
286        impl #builder_name {
287            #( #meta_vars )*
288
289            pub fn new() -> Self {
290                Self {
291                    scope: #scope_type::new(),
292                    #(
293                        #arg_override_fn_field_ctor
294                    )*
295                }
296            }
297
298            #( #arg_override_setters )*
299
300            fn build(&self, cat: &::dill::Catalog) -> Result<#impl_type, ::dill::InjectionError> {
301                use ::dill::DependencySpec;
302                #( #arg_prepare_dependency )*
303                Ok(#ctor)
304            }
305        }
306
307        impl ::dill::Builder for #builder_name {
308            fn instance_type_id(&self) -> ::std::any::TypeId {
309                ::std::any::TypeId::of::<#impl_type>()
310            }
311
312            fn instance_type_name(&self) -> &'static str {
313                ::std::any::type_name::<#impl_type>()
314            }
315
316            fn interfaces(&self, clb: &mut dyn FnMut(&::dill::InterfaceDesc) -> bool) {
317                #(
318                    if !clb(&::dill::InterfaceDesc {
319                        type_id: ::std::any::TypeId::of::<#interfaces>(),
320                        type_name: ::std::any::type_name::<#interfaces>(),
321                    }) { return }
322                )*
323            }
324
325            fn metadata<'a>(&'a self, clb: & mut dyn FnMut(&'a dyn std::any::Any) -> bool) {
326                #( #meta_provide )*
327            }
328
329            fn get_any(&self, cat: &::dill::Catalog) -> Result<::std::sync::Arc<dyn ::std::any::Any + Send + Sync>, ::dill::InjectionError> {
330                Ok(::dill::TypedBuilder::get(self, cat)?)
331            }
332
333            fn check(&self, cat: &::dill::Catalog) -> Result<(), ::dill::ValidationError> {
334                use ::dill::DependencySpec;
335
336                let mut errors = Vec::new();
337                #(
338                if let Err(err) = #arg_check_dependency {
339                    errors.push(err);
340                }
341                )*
342                if errors.len() != 0 {
343                    Err(::dill::ValidationError { errors })
344                } else {
345                    Ok(())
346                }
347            }
348        }
349
350        impl ::dill::TypedBuilder<#impl_type> for #builder_name {
351            fn get(&self, cat: &::dill::Catalog) -> Result<std::sync::Arc<#impl_type>, ::dill::InjectionError> {
352                use ::dill::Scope;
353
354                if let Some(inst) = self.scope.get() {
355                    return Ok(inst.downcast().unwrap());
356                }
357
358                let inst = ::std::sync::Arc::new(self.build(cat)?);
359
360                self.scope.set(inst.clone());
361                Ok(inst)
362            }
363        }
364
365        #(
366            // Allows casting TypedBuider<T> into TypedBuilder<dyn I> for all declared interfaces
367            impl ::dill::TypedBuilderCast<#interfaces> for #builder_name
368            {
369                fn cast(self) -> impl ::dill::TypedBuilder<#interfaces> {
370                    struct B(#builder_name);
371
372                    impl ::dill::Builder for B {
373                        fn instance_type_id(&self) -> ::std::any::TypeId {
374                            self.0.instance_type_id()
375                        }
376                        fn instance_type_name(&self) -> &'static str {
377                            self.0.instance_type_name()
378                        }
379                        fn interfaces(&self, clb: &mut dyn FnMut(&::dill::InterfaceDesc) -> bool) {
380                            self.0.interfaces(clb)
381                        }
382                        fn metadata<'a>(&'a self, clb: &mut dyn FnMut(&'a dyn std::any::Any) -> bool) {
383                            self.0.metadata(clb)
384                        }
385                        fn get_any(&self, cat: &::dill::Catalog) -> Result<std::sync::Arc<dyn std::any::Any + Send + Sync>, ::dill::InjectionError> {
386                            self.0.get_any(cat)
387                        }
388                        fn check(&self, cat: &::dill::Catalog) -> Result<(), ::dill::ValidationError> {
389                            self.0.check(cat)
390                        }
391                    }
392
393                    impl ::dill::TypedBuilder<#interfaces> for B {
394                        fn get(&self, cat: &::dill::Catalog) -> Result<::std::sync::Arc<#interfaces>, ::dill::InjectionError> {
395                            match self.0.get(cat) {
396                                Ok(v) => Ok(v),
397                                Err(e) => Err(e),
398                            }
399                        }
400                    }
401
402                    B(self)
403                }
404            }
405        )*
406    };
407
408    quote! {
409        #component_or_explicit_factory
410
411        #builder
412    }
413    .into()
414}
415
416/////////////////////////////////////////////////////////////////////////////////////////
417
418fn implement_arg(
419    name: &syn::Ident,
420    typ: &syn::Type,
421    builder: &syn::Ident,
422) -> (
423    proc_macro2::TokenStream, // override_fn_field
424    proc_macro2::TokenStream, // override_fn_field_ctor
425    proc_macro2::TokenStream, // override_setters
426    proc_macro2::TokenStream, // prepare_dependency
427    proc_macro2::TokenStream, // provide_dependency
428    proc_macro2::TokenStream, // check_dependency
429) {
430    let override_fn_name = format_ident!("arg_{}_fn", name);
431
432    let injection_type = types::deduce_injection_type(typ);
433
434    let override_fn_field = match &injection_type {
435        InjectionType::Reference { .. } => proc_macro2::TokenStream::new(),
436        _ => quote! {
437            #override_fn_name: Option<Box<dyn Fn(&::dill::Catalog) -> Result<#typ, ::dill::InjectionError> + Send + Sync>>,
438        },
439    };
440
441    let override_fn_field_ctor = match &injection_type {
442        InjectionType::Reference { .. } => proc_macro2::TokenStream::new(),
443        _ => quote! { #override_fn_name: None, },
444    };
445
446    let override_setters = match &injection_type {
447        InjectionType::Reference { .. } => proc_macro2::TokenStream::new(),
448        _ => {
449            let setter_val_name = format_ident!("with_{}", name);
450            let setter_fn_name = format_ident!("with_{}_fn", name);
451            quote! {
452                pub fn #setter_val_name(mut self, val: #typ) -> #builder {
453                    self.#override_fn_name = Some(Box::new(move |_| Ok(val.clone())));
454                    self
455                }
456
457                pub fn #setter_fn_name(
458                    mut self,
459                    fun: impl Fn(&::dill::Catalog) -> Result<#typ, ::dill::InjectionError> + 'static + Send + Sync
460                ) -> #builder {
461                    self.#override_fn_name = Some(Box::new(fun));
462                    self
463                }
464            }
465        }
466    };
467
468    // TODO: Make these rules recursive
469    let do_check_dependency = get_do_check_dependency(&injection_type);
470    let check_dependency = match &injection_type {
471        InjectionType::Reference { .. } => quote! { #do_check_dependency },
472        _ => quote! {
473            match &self.#override_fn_name {
474                Some(_) => Ok(()),
475                _ => #do_check_dependency,
476            }
477        },
478    };
479
480    let do_get_dependency = get_do_get_dependency(&injection_type);
481    let prepare_dependency = match &injection_type {
482        InjectionType::Reference { .. } => quote! { let #name = #do_get_dependency; },
483        _ => quote! {
484            let #name = match &self.#override_fn_name {
485                Some(fun) => fun(cat)?,
486                _ => #do_get_dependency,
487            };
488        },
489    };
490
491    let provide_dependency = match &injection_type {
492        InjectionType::Reference { .. } => quote! { #name.as_ref() },
493        _ => quote! { #name },
494    };
495
496    (
497        override_fn_field,
498        override_fn_field_ctor,
499        override_setters,
500        prepare_dependency,
501        provide_dependency,
502        check_dependency,
503    )
504}
505
506/////////////////////////////////////////////////////////////////////////////////////////
507
508fn get_do_check_dependency(injection_type: &InjectionType) -> proc_macro2::TokenStream {
509    match injection_type {
510        InjectionType::Arc { inner } => quote! { ::dill::OneOf::<#inner>::check(cat) },
511        InjectionType::Reference { inner } => quote! { ::dill::OneOf::<#inner>::check(cat) },
512        InjectionType::Option { element } => match element.as_ref() {
513            InjectionType::Arc { inner } => {
514                quote! { ::dill::Maybe::<::dill::OneOf::<#inner>>::check(cat) }
515            }
516            InjectionType::Value { typ } => {
517                quote! { ::dill::Maybe::<::dill::OneOf::<#typ>>::check(cat) }
518            }
519            _ => {
520                unimplemented!("Currently only Option<Arc<Iface>> and Option<Value> are supported")
521            }
522        },
523        InjectionType::Lazy { element } => match element.as_ref() {
524            InjectionType::Arc { inner } => {
525                quote! { ::dill::specs::Lazy::<::dill::OneOf::<#inner>>::check(cat) }
526            }
527            _ => unimplemented!("Currently only Lazy<Arc<Iface>> is supported"),
528        },
529        InjectionType::Vec { item } => match item.as_ref() {
530            InjectionType::Arc { inner } => quote! { ::dill::AllOf::<#inner>::check(cat) },
531            _ => unimplemented!("Currently only Vec<Arc<Iface>> is supported"),
532        },
533        InjectionType::Value { typ } => quote! { ::dill::OneOf::<#typ>::check(cat) },
534    }
535}
536
537fn get_do_get_dependency(injection_type: &InjectionType) -> proc_macro2::TokenStream {
538    match injection_type {
539        InjectionType::Arc { inner } => quote! { ::dill::OneOf::<#inner>::get(cat)? },
540        InjectionType::Reference { inner } => quote! { ::dill::OneOf::<#inner>::get(cat)? },
541        InjectionType::Option { element } => match element.as_ref() {
542            InjectionType::Arc { inner } => {
543                quote! { ::dill::Maybe::<::dill::OneOf::<#inner>>::get(cat)? }
544            }
545            InjectionType::Value { typ } => {
546                quote! { ::dill::Maybe::<::dill::OneOf::<#typ>>::get(cat)?.map(|v| v.as_ref().clone()) }
547            }
548            _ => {
549                unimplemented!("Currently only Option<Arc<Iface>> and Option<Value> are supported")
550            }
551        },
552        InjectionType::Lazy { element } => match element.as_ref() {
553            InjectionType::Arc { inner } => {
554                quote! { ::dill::specs::Lazy::<::dill::OneOf::<#inner>>::get(cat)? }
555            }
556            _ => unimplemented!("Currently only Lazy<Arc<Iface>> is supported"),
557        },
558        InjectionType::Vec { item } => match item.as_ref() {
559            InjectionType::Arc { inner } => quote! { ::dill::AllOf::<#inner>::get(cat)? },
560            _ => unimplemented!("Currently only Vec<Arc<Iface>> is supported"),
561        },
562        InjectionType::Value { typ } => {
563            quote! { ::dill::OneOf::<#typ>::get(cat).map(|v| v.as_ref().clone())? }
564        }
565    }
566}
567
568/////////////////////////////////////////////////////////////////////////////////////////
569
570fn implement_meta_var(index: usize, expr: &syn::ExprStruct) -> proc_macro2::TokenStream {
571    let ident = format_ident!("_meta_{index}");
572    let typ = &expr.path;
573    quote! {
574        const #ident: #typ = #expr;
575    }
576}
577
578fn implement_meta_provide(index: usize, _expr: &syn::ExprStruct) -> proc_macro2::TokenStream {
579    let ident = format_ident!("_meta_{index}");
580    quote! {
581        if !clb(&Self::#ident) { return }
582    }
583}
584
585/////////////////////////////////////////////////////////////////////////////////////////
586
587/// Searches for `#[scope(X)]` attribute and returns `X`
588fn get_scope(attrs: &Vec<syn::Attribute>) -> Option<syn::Path> {
589    let mut scope = None;
590
591    for attr in attrs {
592        if is_dill_attr(attr, "scope") {
593            attr.parse_nested_meta(|meta| {
594                scope = Some(meta.path);
595                Ok(())
596            })
597            .unwrap();
598        }
599    }
600
601    scope
602}
603
604/////////////////////////////////////////////////////////////////////////////////////////
605
606/// Searches for all `#[interface(X)]` attributes and returns all types
607fn get_interfaces(attrs: &Vec<syn::Attribute>) -> Vec<syn::Type> {
608    let mut interfaces = Vec::new();
609
610    for attr in attrs {
611        if is_dill_attr(attr, "interface") {
612            let iface = attr.parse_args().unwrap();
613            interfaces.push(iface);
614        }
615    }
616
617    interfaces
618}
619
620/////////////////////////////////////////////////////////////////////////////////////////
621
622/// Searches for all `#[meta(X)]` attributes and returns all expressions
623fn get_meta(attrs: &Vec<syn::Attribute>) -> Vec<syn::ExprStruct> {
624    let mut meta = Vec::new();
625
626    for attr in attrs {
627        if is_dill_attr(attr, "meta") {
628            let expr = attr.parse_args().unwrap();
629            meta.push(expr);
630        }
631    }
632
633    meta
634}
635
636/////////////////////////////////////////////////////////////////////////////////////////
637
638fn is_dill_attr<I: ?Sized>(attr: &syn::Attribute, ident: &I) -> bool
639where
640    syn::Ident: PartialEq<I>,
641{
642    if attr.path().is_ident(ident) {
643        true
644    } else {
645        attr.path().segments.len() == 2
646            && &attr.path().segments[0].ident == "dill"
647            && attr.path().segments[1].ident == *ident
648    }
649}
650
651/////////////////////////////////////////////////////////////////////////////////////////
652
653/// Searches `impl` block for `new()` method
654fn get_new(impl_items: &mut [syn::ImplItem]) -> Option<&mut syn::ImplItemFn> {
655    impl_items
656        .iter_mut()
657        .filter_map(|i| match i {
658            syn::ImplItem::Fn(m) => Some(m),
659            _ => None,
660        })
661        .find(|m| m.sig.ident == "new")
662}
663
664/////////////////////////////////////////////////////////////////////////////////////////
665
666fn extract_attr_explicit(attrs: &mut Vec<syn::Attribute>) -> bool {
667    let mut present = false;
668    attrs.retain_mut(|attr| {
669        if is_attr_explicit(attr) {
670            present = true;
671            false
672        } else {
673            true
674        }
675    });
676    present
677}
678
679fn is_attr_explicit(attr: &syn::Attribute) -> bool {
680    if !is_dill_attr(attr, "component") {
681        return false;
682    }
683    let syn::Meta::List(meta) = &attr.meta else {
684        return false;
685    };
686    meta.tokens.to_string().contains("explicit")
687}