dependency_injector_derive/
lib.rs

1//! Derive macros for dependency-injector
2//!
3//! This crate provides the `#[derive(Inject)]` macro for automatic
4//! dependency injection at compile time.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use dependency_injector::{Container, Inject};
10//! use std::sync::Arc;
11//!
12//! #[derive(Clone)]
13//! struct Database {
14//!     url: String,
15//! }
16//!
17//! #[derive(Clone)]
18//! struct Cache {
19//!     size: usize,
20//! }
21//!
22//! #[derive(Inject)]
23//! struct UserService {
24//!     #[inject]
25//!     db: Arc<Database>,
26//!     #[inject]
27//!     cache: Arc<Cache>,
28//!     // Non-injected fields use Default
29//!     request_count: u64,
30//! }
31//!
32//! let container = Container::new();
33//! container.singleton(Database { url: "postgres://localhost".into() });
34//! container.singleton(Cache { size: 1024 });
35//!
36//! let service = UserService::from_container(&container).unwrap();
37//! ```
38
39use proc_macro::TokenStream;
40use quote::quote;
41use syn::{parse_macro_input, DeriveInput, Data, Fields, Type, Attribute};
42
43/// Derive macro for automatic dependency injection.
44///
45/// Generates a `from_container()` method that resolves dependencies
46/// from a `Container` instance.
47///
48/// # Attributes
49///
50/// - `#[inject]` - Mark a field for injection. The field type must be `Arc<T>`.
51/// - `#[inject(optional)]` - Mark a field as optional injection. Uses `Option<Arc<T>>`.
52///
53/// # Generated Methods
54///
55/// - `from_container(container: &Container) -> Result<Self, DiError>` - Creates an instance
56///   by resolving all `#[inject]` fields from the container.
57///
58/// # Example
59///
60/// ```rust,ignore
61/// #[derive(Inject)]
62/// struct MyService {
63///     #[inject]
64///     db: Arc<Database>,
65///     #[inject(optional)]
66///     cache: Option<Arc<Cache>>,
67///     // Fields without #[inject] use Default::default()
68///     counter: u64,
69/// }
70/// ```
71#[proc_macro_derive(Inject, attributes(inject))]
72pub fn derive_inject(input: TokenStream) -> TokenStream {
73    let input = parse_macro_input!(input as DeriveInput);
74
75    let name = &input.ident;
76    let generics = &input.generics;
77    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
78
79    // Only support structs with named fields
80    let fields = match &input.data {
81        Data::Struct(data) => match &data.fields {
82            Fields::Named(fields) => &fields.named,
83            _ => {
84                return syn::Error::new_spanned(
85                    &input,
86                    "Inject can only be derived for structs with named fields"
87                )
88                .to_compile_error()
89                .into();
90            }
91        },
92        _ => {
93            return syn::Error::new_spanned(
94                &input,
95                "Inject can only be derived for structs"
96            )
97            .to_compile_error()
98            .into();
99        }
100    };
101
102    // Parse fields and generate initialization code
103    let mut field_inits = Vec::new();
104
105    for field in fields.iter() {
106        let field_name = field.ident.as_ref().unwrap();
107        let field_type = &field.ty;
108
109        let inject_attr = find_inject_attr(&field.attrs);
110
111        match inject_attr {
112            Some(InjectAttr::Required) => {
113                // Extract inner type from Arc<T>
114                if let Some(inner_type) = extract_arc_inner_type(field_type) {
115                    field_inits.push(quote! {
116                        #field_name: container.get::<#inner_type>()?
117                    });
118                } else {
119                    return syn::Error::new_spanned(
120                        field_type,
121                        "Fields marked with #[inject] must have type Arc<T>"
122                    )
123                    .to_compile_error()
124                    .into();
125                }
126            }
127            Some(InjectAttr::Optional) => {
128                // Extract inner type from Option<Arc<T>>
129                if let Some(inner_type) = extract_option_arc_inner_type(field_type) {
130                    field_inits.push(quote! {
131                        #field_name: container.try_get::<#inner_type>()
132                    });
133                } else {
134                    return syn::Error::new_spanned(
135                        field_type,
136                        "Fields marked with #[inject(optional)] must have type Option<Arc<T>>"
137                    )
138                    .to_compile_error()
139                    .into();
140                }
141            }
142            None => {
143                // Non-injected field - use Default
144                field_inits.push(quote! {
145                    #field_name: ::std::default::Default::default()
146                });
147            }
148        }
149    }
150
151    // Generate the implementation
152    let expanded = quote! {
153        impl #impl_generics #name #ty_generics #where_clause {
154            /// Create an instance by resolving dependencies from a container.
155            ///
156            /// All fields marked with `#[inject]` will be resolved from the container.
157            /// Fields not marked with `#[inject]` will use `Default::default()`.
158            pub fn from_container(
159                container: &::dependency_injector::Container
160            ) -> ::dependency_injector::Result<Self> {
161                Ok(Self {
162                    #(#field_inits),*
163                })
164            }
165        }
166    };
167
168    TokenStream::from(expanded)
169}
170
171/// Types of inject attributes
172enum InjectAttr {
173    Required,
174    Optional,
175}
176
177/// Find and parse the #[inject] attribute
178fn find_inject_attr(attrs: &[Attribute]) -> Option<InjectAttr> {
179    for attr in attrs {
180        if attr.path().is_ident("inject") {
181            // Check if it has (optional) argument
182            if attr.meta.require_path_only().is_ok() {
183                return Some(InjectAttr::Required);
184            }
185
186            // Parse inject(optional)
187            if let Ok(nested) = attr.parse_args::<syn::Ident>() {
188                if nested == "optional" {
189                    return Some(InjectAttr::Optional);
190                }
191            }
192
193            // Default to required
194            return Some(InjectAttr::Required);
195        }
196    }
197    None
198}
199
200/// Extract T from Arc<T>
201fn extract_arc_inner_type(ty: &Type) -> Option<&Type> {
202    if let Type::Path(type_path) = ty {
203        let segment = type_path.path.segments.last()?;
204        if segment.ident == "Arc" {
205            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
206                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
207                    return Some(inner);
208                }
209            }
210        }
211    }
212    None
213}
214
215/// Extract T from Option<Arc<T>>
216fn extract_option_arc_inner_type(ty: &Type) -> Option<&Type> {
217    if let Type::Path(type_path) = ty {
218        let segment = type_path.path.segments.last()?;
219        if segment.ident == "Option" {
220            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
221                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
222                    return extract_arc_inner_type(inner);
223                }
224            }
225        }
226    }
227    None
228}
229