dependency_injector_derive/lib.rs
1//! Derive macros for dependency-injector
2//!
3//! This crate provides the `#[derive(Inject)]` macro for automatic
4//! dependency injection at compile time.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use dependency_injector::{Container, Inject};
10//! use std::sync::Arc;
11//!
12//! #[derive(Clone)]
13//! struct Database {
14//! url: String,
15//! }
16//!
17//! #[derive(Clone)]
18//! struct Cache {
19//! size: usize,
20//! }
21//!
22//! #[derive(Inject)]
23//! struct UserService {
24//! #[inject]
25//! db: Arc<Database>,
26//! #[inject]
27//! cache: Arc<Cache>,
28//! // Non-injected fields use Default
29//! request_count: u64,
30//! }
31//!
32//! let container = Container::new();
33//! container.singleton(Database { url: "postgres://localhost".into() });
34//! container.singleton(Cache { size: 1024 });
35//!
36//! let service = UserService::from_container(&container).unwrap();
37//! ```
38
39use proc_macro::TokenStream;
40use quote::quote;
41use syn::{parse_macro_input, DeriveInput, Data, Fields, Type, Attribute};
42
43/// Derive macro for automatic dependency injection.
44///
45/// Generates a `from_container()` method that resolves dependencies
46/// from a `Container` instance.
47///
48/// # Attributes
49///
50/// - `#[inject]` - Mark a field for injection. The field type must be `Arc<T>`.
51/// - `#[inject(optional)]` - Mark a field as optional injection. Uses `Option<Arc<T>>`.
52///
53/// # Generated Methods
54///
55/// - `from_container(container: &Container) -> Result<Self, DiError>` - Creates an instance
56/// by resolving all `#[inject]` fields from the container.
57///
58/// # Example
59///
60/// ```rust,ignore
61/// #[derive(Inject)]
62/// struct MyService {
63/// #[inject]
64/// db: Arc<Database>,
65/// #[inject(optional)]
66/// cache: Option<Arc<Cache>>,
67/// // Fields without #[inject] use Default::default()
68/// counter: u64,
69/// }
70/// ```
71#[proc_macro_derive(Inject, attributes(inject))]
72pub fn derive_inject(input: TokenStream) -> TokenStream {
73 let input = parse_macro_input!(input as DeriveInput);
74
75 let name = &input.ident;
76 let generics = &input.generics;
77 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
78
79 // Only support structs with named fields
80 let fields = match &input.data {
81 Data::Struct(data) => match &data.fields {
82 Fields::Named(fields) => &fields.named,
83 _ => {
84 return syn::Error::new_spanned(
85 &input,
86 "Inject can only be derived for structs with named fields"
87 )
88 .to_compile_error()
89 .into();
90 }
91 },
92 _ => {
93 return syn::Error::new_spanned(
94 &input,
95 "Inject can only be derived for structs"
96 )
97 .to_compile_error()
98 .into();
99 }
100 };
101
102 // Parse fields and generate initialization code
103 let mut field_inits = Vec::new();
104
105 for field in fields.iter() {
106 let field_name = field.ident.as_ref().unwrap();
107 let field_type = &field.ty;
108
109 let inject_attr = find_inject_attr(&field.attrs);
110
111 match inject_attr {
112 Some(InjectAttr::Required) => {
113 // Extract inner type from Arc<T>
114 if let Some(inner_type) = extract_arc_inner_type(field_type) {
115 field_inits.push(quote! {
116 #field_name: container.get::<#inner_type>()?
117 });
118 } else {
119 return syn::Error::new_spanned(
120 field_type,
121 "Fields marked with #[inject] must have type Arc<T>"
122 )
123 .to_compile_error()
124 .into();
125 }
126 }
127 Some(InjectAttr::Optional) => {
128 // Extract inner type from Option<Arc<T>>
129 if let Some(inner_type) = extract_option_arc_inner_type(field_type) {
130 field_inits.push(quote! {
131 #field_name: container.try_get::<#inner_type>()
132 });
133 } else {
134 return syn::Error::new_spanned(
135 field_type,
136 "Fields marked with #[inject(optional)] must have type Option<Arc<T>>"
137 )
138 .to_compile_error()
139 .into();
140 }
141 }
142 None => {
143 // Non-injected field - use Default
144 field_inits.push(quote! {
145 #field_name: ::std::default::Default::default()
146 });
147 }
148 }
149 }
150
151 // Generate the implementation
152 let expanded = quote! {
153 impl #impl_generics #name #ty_generics #where_clause {
154 /// Create an instance by resolving dependencies from a container.
155 ///
156 /// All fields marked with `#[inject]` will be resolved from the container.
157 /// Fields not marked with `#[inject]` will use `Default::default()`.
158 pub fn from_container(
159 container: &::dependency_injector::Container
160 ) -> ::dependency_injector::Result<Self> {
161 Ok(Self {
162 #(#field_inits),*
163 })
164 }
165 }
166 };
167
168 TokenStream::from(expanded)
169}
170
/// The flavour of injection requested by an `#[inject]` field attribute.
///
/// The derives are free for a fieldless enum and make the type usable in
/// comparisons, assertions and debug output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InjectAttr {
    /// Plain `#[inject]`.
    Required,
    /// `#[inject(optional)]`.
    Optional,
}
176
177/// Find and parse the #[inject] attribute
178fn find_inject_attr(attrs: &[Attribute]) -> Option<InjectAttr> {
179 for attr in attrs {
180 if attr.path().is_ident("inject") {
181 // Check if it has (optional) argument
182 if attr.meta.require_path_only().is_ok() {
183 return Some(InjectAttr::Required);
184 }
185
186 // Parse inject(optional)
187 if let Ok(nested) = attr.parse_args::<syn::Ident>() {
188 if nested == "optional" {
189 return Some(InjectAttr::Optional);
190 }
191 }
192
193 // Default to required
194 return Some(InjectAttr::Required);
195 }
196 }
197 None
198}
199
200/// Extract T from Arc<T>
201fn extract_arc_inner_type(ty: &Type) -> Option<&Type> {
202 if let Type::Path(type_path) = ty {
203 let segment = type_path.path.segments.last()?;
204 if segment.ident == "Arc" {
205 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
206 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
207 return Some(inner);
208 }
209 }
210 }
211 }
212 None
213}
214
215/// Extract T from Option<Arc<T>>
216fn extract_option_arc_inner_type(ty: &Type) -> Option<&Type> {
217 if let Type::Path(type_path) = ty {
218 let segment = type_path.path.segments.last()?;
219 if segment.ident == "Option" {
220 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
221 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
222 return extract_arc_inner_type(inner);
223 }
224 }
225 }
226 }
227 None
228}
229