use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Data, Fields, Type, Attribute};
#[proc_macro_derive(Inject, attributes(inject))]
pub fn derive_inject(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let generics = &input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return syn::Error::new_spanned(
&input,
"Inject can only be derived for structs with named fields"
)
.to_compile_error()
.into();
}
},
_ => {
return syn::Error::new_spanned(
&input,
"Inject can only be derived for structs"
)
.to_compile_error()
.into();
}
};
let mut field_inits = Vec::new();
for field in fields.iter() {
let field_name = field.ident.as_ref().unwrap();
let field_type = &field.ty;
let inject_attr = find_inject_attr(&field.attrs);
match inject_attr {
Some(InjectAttr::Required) => {
if let Some(inner_type) = extract_arc_inner_type(field_type) {
field_inits.push(quote! {
#field_name: container.get::<#inner_type>()?
});
} else {
return syn::Error::new_spanned(
field_type,
"Fields marked with #[inject] must have type Arc<T>"
)
.to_compile_error()
.into();
}
}
Some(InjectAttr::Optional) => {
if let Some(inner_type) = extract_option_arc_inner_type(field_type) {
field_inits.push(quote! {
#field_name: container.try_get::<#inner_type>()
});
} else {
return syn::Error::new_spanned(
field_type,
"Fields marked with #[inject(optional)] must have type Option<Arc<T>>"
)
.to_compile_error()
.into();
}
}
None => {
field_inits.push(quote! {
#field_name: ::std::default::Default::default()
});
}
}
}
let expanded = quote! {
impl #impl_generics #name #ty_generics #where_clause {
pub fn from_container(
container: &::dependency_injector::Container
) -> ::dependency_injector::Result<Self> {
Ok(Self {
#(#field_inits),*
})
}
}
};
TokenStream::from(expanded)
}
enum InjectAttr {
Required,
Optional,
}
fn find_inject_attr(attrs: &[Attribute]) -> Option<InjectAttr> {
for attr in attrs {
if attr.path().is_ident("inject") {
if attr.meta.require_path_only().is_ok() {
return Some(InjectAttr::Required);
}
if let Ok(nested) = attr.parse_args::<syn::Ident>() {
if nested == "optional" {
return Some(InjectAttr::Optional);
}
}
return Some(InjectAttr::Required);
}
}
None
}
/// Returns `T` when `ty` is `Arc<T>`.
///
/// Only the last path segment is inspected, so `Arc<T>`, `std::sync::Arc<T>`,
/// and any other path ending in `Arc<...>` are accepted. Invisible groups
/// (produced by `macro_rules!` `$t:ty` fragments) and parentheses around the
/// type are looked through.
fn extract_arc_inner_type(ty: &Type) -> Option<&Type> {
    first_type_argument(ty, "Arc")
}

/// Returns `T` when `ty` is `Option<Arc<T>>`.
///
/// Both the `Option` and the inner `Arc` are matched on their last path
/// segment, with the same group/parenthesis handling as
/// [`extract_arc_inner_type`].
fn extract_option_arc_inner_type(ty: &Type) -> Option<&Type> {
    first_type_argument(ty, "Option").and_then(extract_arc_inner_type)
}

/// If `ty` is a path type whose last segment is named `wrapper` and whose
/// first generic argument is a type, returns that argument.
fn first_type_argument<'a>(ty: &'a Type, wrapper: &str) -> Option<&'a Type> {
    let type_path = match strip_type_wrappers(ty) {
        Type::Path(type_path) => type_path,
        _ => return None,
    };
    let segment = type_path.path.segments.last()?;
    if segment.ident != wrapper {
        return None;
    }
    match &segment.arguments {
        syn::PathArguments::AngleBracketed(args) => match args.args.first()? {
            syn::GenericArgument::Type(inner) => Some(inner),
            _ => None,
        },
        _ => None,
    }
}

/// Peels `Type::Group` (invisible delimiters from declarative-macro
/// expansion) and `Type::Paren` layers so the underlying type can be matched.
fn strip_type_wrappers(mut ty: &Type) -> &Type {
    loop {
        match ty {
            Type::Group(group) => ty = &group.elem,
            Type::Paren(paren) => ty = &paren.elem,
            _ => return ty,
        }
    }
}