coi_derive/
lib.rs

1#![deny(missing_docs)]
2//! Coi-derive simplifies implementing the traits provided in the [coi] crate.
3//!
4//! [coi]: https://docs.rs/coi
5
6extern crate proc_macro;
7use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, DeriveInput, Error};
10
11mod attr;
12mod ctxt;
13mod symbol;
14
15use crate::attr::Container;
16use crate::ctxt::Ctxt;
17
18/// Generates an impl for `Inject` and also generates a "Provider" struct with its own
19/// `Provide` impl.
20///
21/// This derive proc macro impls `Inject` on the struct it modifies, and also processes #[coi(...)]
22/// attributes:
23/// - `#[coi(provides ...)]` - It takes the form
24/// ```rust,ignore
25/// #[coi(provides <vis> <ty> with <expr>)]
26/// ```
27/// or the form
28/// ```rust,ignore
29/// #[coi(provides <vis> <ty> as <name> with <expr>)]
30/// ```
31/// The latter form *must* be used when generating multiple providers for a single type. This might
32/// be useful if you have multiple trait implementations for one struct and want to provide separate
33/// unique instances for each trait in the container. That use case might be more common with mocks
34/// in unit tests rather than in production code.
35///
36/// It generates a provider struct with visibility `<vis>`
37/// that impls `Provide` with an output type of `Arc<<ty>>`. It will construct `<ty>` with `<expr>`,
38/// and all params to `<expr>` must match the struct fields marked with `#[coi(inject)]` (see the
39/// next bullet item). `<vis>` must match the visibility of `<ty>` or you will get code that might
40/// not compile. If `<name>` is not provided, the struct name will be used and `Provider` will be
41/// appended to it.
42/// - `#[coi(inject)]` - All fields marked `#[coi(inject)]` are resolved in the `provide` fn
43/// described above.
44/// Given a field `<field_name>: <field_ty>`, this attribute will cause the following resolution to
45/// be generated:
46/// ```rust,ignore
47/// let <field_name> = Container::resolve::<<field_ty>>(container, "<field_name>");
48/// ```
49/// Because of this, it's important that the field name *must* match the string that's used to
50/// register the provider in the `ContainerBuilder`.
51///
52/// ## Examples
53///
54/// Private trait and no dependencies
55/// ```rust
56/// use coi::Inject;
57/// # use coi_derive::Inject;
58/// trait Priv: Inject {}
59///
60/// #[derive(Inject)]
61/// #[coi(provides dyn Priv with SimpleStruct)]
62/// # pub
63/// struct SimpleStruct;
64///
65/// impl Priv for SimpleStruct {}
66/// ```
67///
68/// Public trait and dependency
69/// ```rust
70/// use coi::Inject;
71/// # use coi_derive::Inject;
72/// use std::sync::Arc;
73/// pub trait Pub: Inject {}
74/// pub trait Dependency: Inject {}
75///
76/// #[derive(Inject)]
77/// #[coi(provides pub dyn Pub with NewStruct::new(dependency))]
78/// # pub
79/// struct NewStruct {
80///     #[coi(inject)]
81///     dependency: Arc<dyn Dependency>,
82/// }
83///
84/// impl NewStruct {
85///     fn new(dependency: Arc<dyn Dependency>) -> Self {
86///         Self {
87///             dependency
88///         }
89///     }
90/// }
91///
92/// impl Pub for NewStruct {}
93/// ```
94///
95/// Struct injection
96/// ```rust
97/// use coi::Inject;
98/// # use coi_derive::Inject;
99///
100/// #[derive(Inject)]
101/// #[coi(provides pub InjectableStruct with InjectableStruct)]
102/// # pub
103/// struct InjectableStruct;
104/// ```
105///
106/// Unnamed fields
107/// ```rust
108/// use coi::Inject;
109/// # use coi_derive::Inject;
110/// use std::sync::Arc;
111///
112/// #[derive(Inject)]
113/// #[coi(provides Dep1 with Dep1)]
114/// struct Dep1;
115///
116/// #[derive(Inject)]
117/// #[coi(provides Impl1 with Impl1(dep1))]
118/// struct Impl1(#[coi(inject = "dep1")] Arc<Dep1>);
119/// ```
120///
121/// Generics
122/// ```rust
123/// use coi::{container, Inject};
124/// # use coi_derive::Inject;
125///
126/// #[derive(Inject)]
127/// #[coi(provides Impl1<T> with Impl1::<T>::new())]
128/// struct Impl1<T>(T)
129/// where
130///     T: Default;
131///
132/// impl<T> Impl1<T>
133/// where
134///     T: Default,
135/// {
136///     fn new() -> Self {
137///         Self(Default::default())
138///     }
139/// }
140///
141/// fn build_container() {
142///   // Take note that these providers have to be constructed
143///   // with explicit types.
144///   let impl1_provider = Impl1Provider::<bool>::new();
145///   let container = container! {
146///       impl1 => impl1_provider,
147///   };
148///   let _bool_impl = container
149///       .resolve::<Impl1<bool>>("impl1")
150///       .expect("Should exist");
151/// }
152///
153/// # build_container();
154/// ```
155///
156/// If you need some form of constructor fn that takes arguments that are not injected, then you
157/// might be able to use the [`coi::Provide`] derive. If that doesn't fit your use case, you'll
158/// need to manually implement `Provide`.
159///
160/// [`coi::Provide`]: derive.Provide.html
161#[proc_macro_derive(Inject, attributes(coi))]
162pub fn inject_derive(input: TokenStream) -> TokenStream {
163    let input = parse_macro_input!(input as DeriveInput);
164    let cx = Ctxt::new();
165    let container = Container::from_ast(&cx, &input, true);
166    if let Err(e) = cx.check() {
167        return to_compile_errors(e).into();
168    }
169    let container = container.unwrap();
170
171    let has_generics = !input.generics.params.is_empty();
172    let generic_params = input.generics.params;
173    let generics = if has_generics {
174        quote! {
175            <#generic_params>
176        }
177    } else {
178        quote! {}
179    };
180
181    let coi = container.coi_path();
182    let where_clause = input
183        .generics
184        .where_clause
185        .map(|w| {
186            let t: Vec<_> = generic_params.iter().collect();
187            quote! { #w #(, #t: Send + Sync + 'static )* }
188        })
189        .unwrap_or_default();
190    if container.providers.is_empty() {
191        let ident = input.ident;
192        return quote! {
193            impl #generics #coi::Inject for #ident #generics #where_clause {}
194        }
195        .into();
196    }
197
198    let container_ident = format_ident!(
199        "{}",
200        if container.injected.is_empty() {
201            "_"
202        } else {
203            "container"
204        }
205    );
206    let (resolve, keys): (Vec<_>, Vec<_>) = container
207        .injected
208        .into_iter()
209        .map(|field| {
210            let ident = field.name;
211            let ty = field.ty;
212            let key = format!("{}", ident);
213            (
214                quote! {
215                    let #ident = #container_ident.resolve::<#ty>(#key)?;
216                },
217                key,
218            )
219        })
220        .unzip();
221    let input_ident = input.ident;
222
223    let dependencies_fn = if cfg!(feature = "debug") {
224        vec![quote! {
225            fn dependencies(&self) -> &'static[&'static str] {
226                &[
227                    #( #keys, )*
228                ]
229            }
230        }]
231    } else {
232        vec![]
233    };
234
235    let provider_fields = if has_generics {
236        let tys: Vec<_> = generic_params.iter().cloned().collect();
237        quote! {
238            (
239                #( ::std::marker::PhantomData<#tys> )*
240            )
241        }
242    } else {
243        quote! {}
244    };
245
246    let phantom_data: Vec<_> = generic_params
247        .iter()
248        .map(|_| quote! {::std::marker::PhantomData})
249        .collect();
250
251    let provider_impls = if !phantom_data.is_empty() {
252        container
253            .providers
254            .iter()
255            .map(|p| {
256                let provider = p.name_or(&input_ident);
257                let vis = &p.vis;
258                quote! {
259                    impl #generics #provider #generics #where_clause {
260                        #vis fn new() -> Self {
261                            Self(#( #phantom_data )*)
262                        }
263                    }
264                }
265            })
266            .collect()
267    } else {
268        vec![]
269    };
270
271    let constructed_provides: Vec<_> = container
272        .providers
273        .into_iter()
274        .map(|p| {
275            let provider = p.name_or(&input_ident);
276            let vis = p.vis;
277            let ty = p.ty;
278            let provides_with = p.with;
279
280            quote! {
281                #vis struct #provider #generics #provider_fields #where_clause;
282
283                impl #generics #coi::Provide for #provider #generics #where_clause {
284                    type Output = #ty;
285
286                    fn provide(
287                        &self,
288                        #container_ident: &#coi::Container,
289                    ) -> #coi::Result<::std::sync::Arc<Self::Output>> {
290                        #( #resolve )*
291                        Ok(::std::sync::Arc::new(#provides_with) as ::std::sync::Arc<#ty>)
292                    }
293
294                    #( #dependencies_fn )*
295                }
296            }
297        })
298        .collect();
299
300    let expanded = quote! {
301        impl #generics #coi::Inject for #input_ident #generics #where_clause {}
302
303        #( #provider_impls )*
304        #( #constructed_provides )*
305    };
306    TokenStream::from(expanded)
307}
308
309/// Generates an impl for `Provide` and also generates a "Provider" struct with its own
310/// `Provide` impl.
311///
312/// This derive proc macro impls `Provide` on the struct it modifies, and also processes #[coi(...)]
313/// attributes:
314/// - `#[coi(provides ...)]` - It takes the form
315/// ```rust,ignore
316/// #[coi(provides <vis> <ty> with <expr>)]
317/// ```
318///
319/// Multiple `provides` attributes are not allowed since this is for a specific `Provide` impl and
320/// not for the resolved type.
321///
322/// It generates a provider struct with visibility `<vis>`
323/// that impls `Provide` with an output type of `Arc<<ty>>`. It will construct `<ty>` with `<expr>`,
324/// and all params to `<expr>` must match the struct fields marked with `#[coi(inject)]` (see the
325/// next bullet item). `<vis>` must match the visibility of `<ty>` or you will get code that might
326/// not compile. If `<name>` is not provided, the struct name will be used and `Provider` will be
327/// appended to it.
328///
329/// ## Examples
330///
331/// Private trait and no dependencies
332/// ```rust
333/// use coi::{Inject, Provide};
334/// # use coi_derive::{Inject, Provide};
335/// trait Priv: Inject {}
336///
337/// #[derive(Inject)]
338/// # pub
339/// struct SimpleStruct {
340///     data: u32
341/// }
342///
343/// impl SimpleStruct {
344///     fn new(data: u32) -> Self {
345///         Self { data }
346///     }
347/// }
348///
349/// impl Priv for SimpleStruct {}
350///
351/// #[derive(Provide)]
352/// #[coi(provides dyn Priv with SimpleStruct::new(self.data))]
353/// struct SimpleStructProvider {
354///     data: u32,
355/// }
356///
357/// impl SimpleStructProvider {
358///     fn new(data: u32) -> Self {
359///         Self { data: 42 }
360///     }
361/// }
362/// ```
363#[proc_macro_derive(Provide, attributes(coi))]
364pub fn provide_derive(input: TokenStream) -> TokenStream {
365    let input = parse_macro_input!(input as DeriveInput);
366    let cx = Ctxt::new();
367    let container = Container::from_ast(&cx, &input, false);
368    if let Err(e) = cx.check() {
369        return to_compile_errors(e).into();
370    }
371    let container = container.unwrap();
372
373    let provider = input.ident.clone();
374    let has_generics = !input.generics.params.is_empty();
375    let generic_params = input.generics.params;
376    let generics = if has_generics {
377        quote! {
378            <#generic_params>
379        }
380    } else {
381        quote! {}
382    };
383    let where_clause = input
384        .generics
385        .where_clause
386        .map(|w| {
387            let t: Vec<_> = generic_params.iter().collect();
388            quote! { #w #(, #t: Send + Sync + 'static )* }
389        })
390        .unwrap_or_default();
391
392    let dependencies_fn = if cfg!(feature = "debug") {
393        vec![{
394            quote! {
395                fn dependencies(
396                    &self
397                ) -> &'static [&'static str] {
398                    &[]
399                }
400            }
401        }]
402    } else {
403        vec![]
404    };
405
406    let coi = container.coi_path();
407    let expanded: Vec<_> = container
408        .providers
409        .into_iter()
410        .map(|p| {
411            let ty = p.ty;
412            let provides_with = p.with;
413            quote! {
414                impl #generics #coi::Provide for #provider #generics #where_clause {
415                    type Output = #ty;
416
417                    fn provide(
418                        &self,
419                        _: &#coi::Container,
420                    ) -> #coi::Result<::std::sync::Arc<Self::Output>> {
421                        Ok(::std::sync::Arc::new(#provides_with) as ::std::sync::Arc<#ty>)
422                    }
423
424                    #( #dependencies_fn )*
425                }
426            }
427        })
428        .collect();
429    TokenStream::from(quote! {
430        #( #expanded )*
431    })
432}
433
434fn to_compile_errors(errors: Vec<Error>) -> proc_macro2::TokenStream {
435    let compile_errors = errors.iter().map(Error::to_compile_error);
436    quote!(#(#compile_errors)*)
437}