di_macros/
lib.rs

1mod alias;
2mod internal;
3
4extern crate proc_macro;
5
6use crate::internal::*;
7use internal::{Constructor, DeriveContext, InjectableTrait};
8use proc_macro2::TokenStream;
9use syn::{
10    punctuated::Punctuated,
11    spanned::Spanned,
12    token::{PathSep, Plus},
13    *,
14};
15
16/// Represents the metadata used to identify an injected function.
17///
18/// # Remarks
19///
20/// The default behavior looks for an associated function with the
21/// name `new`. To change this behavior, decorate the function to
22/// be used with `#[inject]`. This attribute may only be applied
23/// to a single function.
24#[proc_macro_attribute]
25pub fn inject(
26    _metadata: proc_macro::TokenStream,
27    input: proc_macro::TokenStream,
28) -> proc_macro::TokenStream {
29    // this attribute is intentionally inert
30    input
31}
32
33/// Represents the metadata used to implement the `Injectable` trait.
34///
35/// # Arguments
36///
37/// * `trait` - the optional name of the trait the implementation satisfies.
38///
39/// # Remarks
40///
41/// This attribute may be applied a struct definition or a struct `impl`
42/// block. The defining struct implementation block must either have an
43/// associated function named `new` or decorate the injected function with
44/// `#[inject]`. The injected function does not have to be public.
45///
46/// If `trait` is not specified, then the implementation will
47/// injectable as the defining struct itself.
48///
49/// The injected call site arguments are restricted to the same return
50/// values supported by `ServiceProvider`, which can only be:
51///
52/// * `Ref<T>`
53/// * `RefMut<T>`
54/// * `Option<Ref<T>>`
55/// * `Option<RefMut<T>>`
56/// * `Vec<Ref<T>>`
57/// * `Vec<RefMut<T>>`
58/// * `impl Iterator<Item = Ref<T>>`
59/// * `impl Iterator<Item = RefMut<T>>`
60/// * `Lazy<Ref<T>>`
61/// * `Lazy<RefMut<T>>`
62/// * `Lazy<Option<Ref<T>>>`
63/// * `Lazy<Option<RefMut<T>>>`
64/// * `Lazy<Vec<Ref<T>>>`
65/// * `Lazy<Vec<RefMut<T>>>`
66/// * `KeyedRef<TKey, TSvc>`
67/// * `KeyedRefMut<TKey, TSvc>`
68/// * `ServiceProvider`
69/// * `ScopedServiceProvider`
70///
71/// `Ref<T>` is a type alias for `Rc<T>` or `Arc<T>` and `RefMut<T>` is a
72/// type alias for `Rc<RefCell<T>>` or `Arc<RwLock<T>>` depending on whether
73/// the **async** feature is activated; therefore, `Rc<T>` and `Arc<T>`
74/// are allowed any place `Ref<T>` is allowed and `Rc<RefCell<T>>`
75/// and `Arc<RwLock<T>>` are allowed any place `RefMut<T>` is allowed.
76///
77/// # Examples
78///
79/// Injecting a struct as a trait.
80///
81/// ```
82/// pub trait Foo {
83///    fn do_work(&self);
84/// }
85///
86/// pub struct FooImpl;
87///
88/// impl Foo for FooImpl {
89///     fn do_work(&self) {
90///         println!("Did something!");
91///     }
92/// }
93///
94/// #[injectable(Foo)]
95/// impl FooImpl {
96///     pub fn new() -> Self {
97///         Self {}
98///     }
99/// }
100/// ```
101///
102/// Injecting a struct as itself.
103///
104/// ```
105/// #[injectable]
106/// pub struct Foo;
107///
108/// impl Foo {
109///     fn do_work(&self) {
110///         println!("Did something!");
111///     }
112/// }
113/// ```
114///
115/// Define a custom injection function.
116///
117/// ```
118/// pub struct Bar;
119/// pub struct Foo {
120///     bar: di::Ref<Bar>
121/// };
122///
123/// #[injectable]
124/// impl Foo {
125///     #[inject]
126///     pub fn create(bar: di::Ref<Bar>) -> Self {
127///         Self { bar }
128///     }
129/// }
130#[proc_macro_attribute]
131pub fn injectable(
132    metadata: proc_macro::TokenStream,
133    input: proc_macro::TokenStream,
134) -> proc_macro::TokenStream {
135    proc_macro::TokenStream::from(_injectable(
136        TokenStream::from(metadata),
137        TokenStream::from(input),
138    ))
139}
140
141fn _injectable(metadata: TokenStream, input: TokenStream) -> TokenStream {
142    let original = TokenStream::from(input.clone());
143    let result = match parse2::<InjectableAttribute>(metadata) {
144        Ok(attribute) => {
145            if let Ok(impl_) = parse2::<ItemImpl>(TokenStream::from(input.clone())) {
146                derive_from_struct_impl(impl_, attribute, original)
147            } else if let Ok(struct_) = parse2::<ItemStruct>(TokenStream::from(input)) {
148                derive_from_struct(struct_, attribute, original)
149            } else {
150                Err(Error::new(
151                    original.span(),
152                    "Attribute can only be applied to a structure or structure implementation block.",
153                ))
154            }
155        }
156        Err(error) => Err(error),
157    };
158
159    match result {
160        Ok(output) => output,
161        Err(error) => error.to_compile_error().into(),
162    }
163}
164
165fn derive_from_struct_impl(
166    impl_: ItemImpl,
167    attribute: InjectableAttribute,
168    original: TokenStream,
169) -> Result<TokenStream> {
170    if let Type::Path(type_) = &*impl_.self_ty {
171        let imp = &type_.path;
172        let svc = service_from_attribute(imp, attribute);
173        match Constructor::select(&impl_, imp) {
174            Ok(method) => {
175                let context = DeriveContext::for_method(&impl_.generics, imp, svc, method);
176                derive(context, original)
177            }
178            Err(error) => Err(error),
179        }
180    } else {
181        Err(Error::new(impl_.span(), "Expected implementation type."))
182    }
183}
184
185fn derive_from_struct(
186    struct_: ItemStruct,
187    attribute: InjectableAttribute,
188    original: TokenStream,
189) -> Result<TokenStream> {
190    let imp = &build_path_from_struct(&struct_);
191    let svc = service_from_attribute(imp, attribute);
192    let context = DeriveContext::for_struct(&struct_.generics, imp, svc, &struct_);
193
194    derive(context, original)
195}
196
197fn service_from_attribute(impl_: &Path, mut attribute: InjectableAttribute) -> Punctuated<Path, Plus> {
198    let mut punctuated = attribute.trait_.take().unwrap_or_else(Punctuated::<Path, Plus>::new);
199
200    if punctuated.is_empty() {
201        punctuated.push(impl_.clone());
202    }
203
204    punctuated
205}
206
207fn build_path_from_struct(struct_: &ItemStruct) -> Path {
208    let generics = &struct_.generics;
209    let mut segments = Punctuated::<PathSegment, PathSep>::new();
210    let segment = PathSegment {
211        ident: struct_.ident.clone(),
212        arguments: if generics.params.is_empty() {
213            PathArguments::None
214        } else {
215            let mut args = Punctuated::<GenericArgument, Token![,]>::new();
216
217            for param in &generics.params {
218                args.push(match param {
219                    GenericParam::Const(_) => continue,
220                    GenericParam::Type(type_) => GenericArgument::Type(Type::Path(TypePath {
221                        qself: None,
222                        path: Path::from(type_.ident.clone()),
223                    })),
224                    GenericParam::Lifetime(param) => {
225                        GenericArgument::Lifetime(param.lifetime.clone())
226                    }
227                });
228            }
229
230            PathArguments::AngleBracketed(AngleBracketedGenericArguments {
231                colon2_token: None,
232                gt_token: Default::default(),
233                args,
234                lt_token: Default::default(),
235            })
236        },
237    };
238
239    segments.push(segment);
240
241    Path {
242        leading_colon: None,
243        segments,
244    }
245}
246
247#[inline]
248fn derive<'a>(context: DeriveContext<'a>, mut original: TokenStream) -> Result<TokenStream> {
249    match InjectableTrait::derive(&context) {
250        Ok(injectable) => {
251            original.extend(injectable.into_iter());
252            Ok(original)
253        }
254        Err(error) => Err(error),
255    }
256}
257
#[cfg(test)]
mod test {
    use super::*;

    /// Expands `#[injectable(<metadata>)]` applied to `input` and checks the
    /// printed result.
    ///
    /// Both arguments are Rust source text that is tokenized with
    /// `proc_macro2`. `expected` is compared against `TokenStream::to_string`,
    /// so it must reproduce the exact token spacing that `proc_macro2` prints.
    fn assert_expands_to(metadata: &str, input: &str, expected: &str) {
        let metadata: TokenStream = metadata.parse().unwrap();
        let input: TokenStream = input.parse().unwrap();

        let actual = _injectable(metadata, input).to_string();

        assert_eq!(expected, actual);
    }

    #[test]
    fn attribute_should_implement_injectable_by_convention() {
        assert_expands_to(
            "Foo",
            r#"
            impl FooImpl {
                fn new() -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "fn new () -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Foo , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new ()) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new () . into ())) , ",
                "lifetime) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_implement_injectable_using_decorated_method() {
        assert_expands_to(
            "Foo",
            r#"
            impl FooImpl {
                #[inject]
                fn create() -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "# [inject] ",
                "fn create () -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Foo , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: create ()) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: create () . into ())) , ",
                "lifetime) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_inject_required_dependency() {
        assert_expands_to(
            "Foo",
            r#"
            impl FooImpl {
                fn new(_bar: Rc<dyn Bar>) -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "fn new (_bar : Rc < dyn Bar >) -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Foo , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new (sp . get_required :: < dyn Bar > ())) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new (sp . get_required :: < dyn Bar > ()) . into ())) , ",
                "lifetime) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < dyn Bar > () , di :: ServiceCardinality :: ExactlyOne)) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_inject_optional_dependency() {
        assert_expands_to(
            "Foo",
            r#"
            impl FooImpl {
                fn new(_bar: Option<Rc<dyn Bar>>) -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "fn new (_bar : Option < Rc < dyn Bar >>) -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Foo , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new (sp . get :: < dyn Bar > ())) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new (sp . get :: < dyn Bar > ()) . into ())) , ",
                "lifetime) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < dyn Bar > () , di :: ServiceCardinality :: ZeroOrOne)) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_inject_dependency_collection() {
        assert_expands_to(
            "Foo",
            r#"
            impl FooImpl {
                fn new(_bars: Vec<Rc<dyn Bar>>) -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "fn new (_bars : Vec < Rc < dyn Bar >>) -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Foo , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new (sp . get_all :: < dyn Bar > () . collect ())) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new (sp . get_all :: < dyn Bar > () . collect ()) . into ())) , ",
                "lifetime) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < dyn Bar > () , di :: ServiceCardinality :: ZeroOrMore)) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_inject_multiple_dependencies() {
        assert_expands_to(
            "Thing",
            r#"
            impl ThingImpl {
                #[inject]
                fn create_new(_foo: Ref<dyn Foo>, _bar: Option<Ref<dyn Bar>>) -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl ThingImpl { ",
                "# [inject] ",
                "fn create_new (_foo : Ref < dyn Foo >, _bar : Option < Ref < dyn Bar >>) -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for ThingImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Thing , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: create_new (sp . get_required :: < dyn Foo > () , sp . get :: < dyn Bar > ())) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: create_new (sp . get_required :: < dyn Foo > () , sp . get :: < dyn Bar > ()) . into ())) , ",
                "lifetime) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < dyn Foo > () , di :: ServiceCardinality :: ExactlyOne)) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < dyn Bar > () , di :: ServiceCardinality :: ZeroOrOne)) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_implement_injectable_for_self() {
        // Empty attribute arguments: the type is injected as itself.
        assert_expands_to(
            "",
            r#"
            impl FooImpl {
                fn new() -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "fn new () -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < Self , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new ()) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new () . into ())) , ",
                "lifetime) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_implement_injectable_for_struct() {
        assert_expands_to(
            "Foo",
            r#"
            impl FooImpl {
                fn new(_bar: Rc<Bar>) -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl FooImpl { ",
                "fn new (_bar : Rc < Bar >) -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl di :: Injectable for FooImpl { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Foo , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new (sp . get_required :: < Bar > ())) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new (sp . get_required :: < Bar > ()) . into ())) , ",
                "lifetime) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < Bar > () , di :: ServiceCardinality :: ExactlyOne)) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_implement_injectable_for_generic_struct() {
        // Empty attribute arguments on a generic `impl` block.
        assert_expands_to(
            "",
            r#"
            impl<T: Default> GenericBar<T> {
                fn new() -> Self {
                    Self { }
                }
            }
            "#,
            concat!(
                "impl < T : Default > GenericBar < T > { ",
                "fn new () -> Self { ",
                "Self { } ",
                "} ",
                "} ",
                "impl < T : Default > di :: Injectable for GenericBar < T > { ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < Self , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new ()) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new () . into ())) , ",
                "lifetime) ",
                "} ",
                "}"
            ),
        );
    }

    #[test]
    fn attribute_should_implement_injectable_for_generic_trait() {
        assert_expands_to(
            "Pair<TKey, TValue>",
            r#"
            impl<TKey, TValue> PairImpl<TKey, TValue>
            where
                TKey: Debug,
                TValue: Debug
            {
                fn new(key: Ref<TKey>, value: Ref<TValue>) -> Self {
                    Self { key, value }
                }
            }
            "#,
            concat!(
                "impl < TKey , TValue > PairImpl < TKey , TValue > ",
                "where ",
                "TKey : Debug , ",
                "TValue : Debug ",
                "{ ",
                "fn new (key : Ref < TKey >, value : Ref < TValue >) -> Self { ",
                "Self { key , value } ",
                "} ",
                "} ",
                "impl < TKey , TValue > di :: Injectable for PairImpl < TKey , TValue > ",
                "where ",
                "TKey : Debug , ",
                "TValue : Debug ",
                "{ ",
                "fn inject (lifetime : di :: ServiceLifetime) -> di :: InjectBuilder { ",
                "di :: InjectBuilder :: new (",
                "di :: Activator :: new :: < dyn Pair < TKey , TValue > , Self > (",
                "| sp : & di :: ServiceProvider | di :: Ref :: new (Self :: new (sp . get_required :: < TKey > () , sp . get_required :: < TValue > ())) , ",
                "| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new (sp . get_required :: < TKey > () , sp . get_required :: < TValue > ()) . into ())) , ",
                "lifetime) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < TKey > () , di :: ServiceCardinality :: ExactlyOne)) ",
                ". depends_on (di :: ServiceDependency :: new (di :: Type :: of :: < TValue > () , di :: ServiceCardinality :: ExactlyOne)) ",
                "} ",
                "}"
            ),
        );
    }
}