use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
FnArg, GenericArgument, ImplItem, ImplItemFn, Item, ItemFn, ItemImpl, ItemStruct, LitStr, Pat,
PathArguments, ReturnType, Signature, Type, parse_macro_input, visit::Visit,
};
fn auto_di_path() -> proc_macro2::TokenStream {
match proc_macro_crate::crate_name("auto-di") {
Ok(proc_macro_crate::FoundCrate::Itself) => quote!(::auto_di),
Ok(proc_macro_crate::FoundCrate::Name(name)) => {
let name = format_ident!("{name}");
quote!(::#name)
}
Err(_) => quote!(::auto_di),
}
}
/// Options accepted by `#[singleton(...)]` / `#[provider(...)]`.
///
/// `Default` yields the option-less configuration: every flag is `false` and
/// every optional setting is `None`.
#[derive(Default, Clone)]
struct ProviderOptions {
    /// `name = "..."`; forwarded to `ProviderDescriptor::configured` as the provider name.
    name: Option<String>,
    /// `primary` flag; forwarded to the provider descriptor.
    primary: bool,
    /// `scope = "singleton" | "prototype" | "request"`; `None` is emitted as `Scope::Singleton`.
    scope: Option<String>,
    /// `eager` flag; forwarded to the provider descriptor.
    eager: bool,
    /// `blocking` flag; a synchronous factory call is wrapped in Tokio's `spawn_blocking`.
    blocking: bool,
    /// `profile = "..."`; forwarded to the provider descriptor.
    profile: Option<String>,
    /// Key part of `condition = "KEY"` or `condition = "KEY=VALUE"`.
    condition_key: Option<String>,
    /// Value part of `condition = "KEY=VALUE"` (`None` when no `=` was given).
    condition_value: Option<String>,
    /// `post_construct = "method"`; this method is awaited on the freshly built value.
    post_construct: Option<String>,
    /// `pre_destroy = "method"`; this method is awaited by the generated destroy hook.
    pre_destroy: Option<String>,
}
pub(crate) fn singleton(attribute: TokenStream, item: TokenStream) -> TokenStream {
let options = match parse_provider_options(attribute) {
Ok(options) => options,
Err(error) => return error.to_compile_error().into(),
};
let item = parse_macro_input!(item as Item);
let expanded = match item {
Item::Fn(function) => expand_function(function, &options),
Item::Impl(item_impl) => expand_impl(item_impl, &options),
other => Err(syn::Error::new_spanned(
other,
"#[singleton] can only be used on a function or an inherent impl block",
)),
};
match expanded {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
fn parse_provider_options(attribute: TokenStream) -> syn::Result<ProviderOptions> {
parse_provider_options2(attribute.into())
}
fn parse_provider_options2(attribute: proc_macro2::TokenStream) -> syn::Result<ProviderOptions> {
use syn::parse::Parser;
let mut options = ProviderOptions::default();
let parser = syn::meta::parser(|meta| {
if meta.path.is_ident("primary") {
options.primary = true;
return Ok(());
}
if meta.path.is_ident("eager") {
options.eager = true;
return Ok(());
}
if meta.path.is_ident("blocking") {
options.blocking = true;
return Ok(());
}
let value = meta.value()?.parse::<LitStr>()?.value();
if meta.path.is_ident("name") {
options.name = Some(value);
} else if meta.path.is_ident("scope") {
options.scope = Some(value);
} else if meta.path.is_ident("profile") {
options.profile = Some(value);
} else if meta.path.is_ident("condition") {
let (key, expected) = value
.split_once('=')
.map_or((value.as_str(), None), |(k, v)| (k, Some(v)));
options.condition_key = Some(key.to_owned());
options.condition_value = expected.map(str::to_owned);
} else if meta.path.is_ident("post_construct") {
options.post_construct = Some(value);
} else if meta.path.is_ident("pre_destroy") {
options.pre_destroy = Some(value);
} else {
return Err(meta.error("unknown provider option"));
}
Ok(())
});
parser.parse2(attribute)?;
if let Some(scope) = &options.scope {
if !matches!(scope.as_str(), "singleton" | "prototype" | "request") {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"scope must be singleton, prototype, or request",
));
}
}
Ok(options)
}
/// Pass-through attribute macro for `#[qualifier(...)]`.
///
/// Ignores its arguments and returns `item` unchanged. Qualifiers on injected
/// parameters are read and stripped by the `#[singleton]` and `#[configuration]`
/// expansions through `take_qualifier`, so this macro does no work itself.
pub(crate) fn qualifier(_attribute: TokenStream, item: TokenStream) -> TokenStream {
    item
}
pub(crate) fn application(_attribute: TokenStream, item: TokenStream) -> TokenStream {
let auto_di = auto_di_path();
let mut function = parse_macro_input!(item as ItemFn);
if function.sig.asyncness.is_none() {
return syn::Error::new_spanned(&function.sig, "#[application] requires async fn")
.to_compile_error()
.into();
}
let body = function.block;
function.sig.asyncness = None;
function.block = Box::new(syn::parse_quote!({
let __runtime = #auto_di::__private::tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("auto-di failed to create the Tokio runtime");
__runtime.block_on(async move {
let __container = #auto_di::global_container()?;
__container.validate().await?;
let __result = (async move #body).await;
__container.shutdown().await?;
__result
})
}));
quote!(#function).into()
}
pub(crate) fn configuration_properties(attribute: TokenStream, item: TokenStream) -> TokenStream {
let prefix = parse_macro_input!(attribute as LitStr).value();
let structure = parse_macro_input!(item as ItemStruct);
match expand_configuration_properties(prefix, structure) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
/// Expands `#[configuration_properties("prefix")]` on a struct with named fields.
///
/// Emits three things:
/// - the struct unchanged;
/// - an `auto_di::ConfigurationProperties` impl whose `from_environment` reads
///   every field from an environment variable and parses it with `str::parse`;
/// - a registration that provides the struct via `from_environment()`.
///
/// Each variable name is built as follows:
/// - it is `PREFIX_FIELD`, or just `FIELD` when the prefix is empty;
/// - `.` and `-` are replaced by `_`;
/// - the result is upper-cased;
/// - raw identifiers such as `r#type` contribute `TYPE`.
///
/// A missing variable, or a value that fails to parse, produces
/// `DiError::Configuration`.
///
/// # Errors
/// Rejects tuple and unit structs. Also rejects generic structs, because the
/// generated impl names the bare type.
fn expand_configuration_properties(
    prefix: String,
    structure: ItemStruct,
) -> syn::Result<proc_macro2::TokenStream> {
    use syn::ext::IdentExt;
    let auto_di = auto_di_path();
    let ident = &structure.ident;
    let fields = match &structure.fields {
        syn::Fields::Named(fields) => &fields.named,
        _ => {
            return Err(syn::Error::new_spanned(
                &structure,
                "configuration properties require named fields",
            ));
        }
    };
    if !structure.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &structure.generics,
            "generic configuration properties are not supported",
        ));
    }
    let bindings = fields.iter().map(|field| {
        let name = field.ident.as_ref().expect("named field");
        let ty = &field.ty;
        // `unraw` keeps `r#type` from leaking a `#` into the variable name.
        let field_key = name.unraw().to_string();
        // Without a prefix, avoid producing a leading underscore.
        let qualified_key = if prefix.is_empty() {
            field_key
        } else {
            format!("{prefix}_{field_key}")
        };
        let key = qualified_key.replace(['.', '-'], "_").to_uppercase();
        quote! {
            #name: ::std::env::var(#key)
                .map_err(|error| #auto_di::DiError::Configuration { key: #key.into(), message: error.to_string() })?
                .parse::<#ty>()
                .map_err(|_| #auto_di::DiError::Configuration { key: #key.into(), message: "value could not be parsed".into() })?
        }
    });
    let provider_name = format_ident!("configuration_properties_{ident}");
    let invocation = quote! { <#ident as #auto_di::ConfigurationProperties>::from_environment()? };
    let registration = registration(
        &provider_name,
        &syn::parse_quote!(#ident),
        &[],
        &[],
        &[],
        invocation,
    );
    Ok(quote! {
        #structure
        impl #auto_di::ConfigurationProperties for #ident {
            fn from_environment() -> ::std::result::Result<Self, #auto_di::DiError> {
                Ok(Self { #(#bindings),* })
            }
        }
        #registration
    })
}
pub(crate) fn configuration(_attribute: TokenStream, item: TokenStream) -> TokenStream {
let item_impl = parse_macro_input!(item as ItemImpl);
match expand_configuration(item_impl) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
pub(crate) fn provider(attribute: TokenStream, item: TokenStream) -> TokenStream {
singleton(attribute, item)
}
fn expand_function(
mut function: ItemFn,
options: &ProviderOptions,
) -> syn::Result<proc_macro2::TokenStream> {
let function_name = function.sig.ident.clone();
let (output_type, fallible) = provided_output(&function.sig.output)?;
let (argument_names, dependency_types, qualifiers) = dependencies(&mut function.sig)?;
let is_async = function.sig.asyncness.is_some();
let invocation = if is_async {
quote! { #function_name(#(#argument_names),*).await }
} else {
quote! { #function_name(#(#argument_names),*) }
};
let invocation = blocking_invocation(invocation, options.blocking, is_async, &output_type)?;
let invocation = factory_invocation(invocation, fallible, &output_type);
let registration = registration_with_options(
&function_name,
&output_type,
&argument_names,
&dependency_types,
&qualifiers,
invocation,
options,
);
Ok(quote! {
#function
#registration
})
}
/// Expands `#[singleton]` on an inherent `impl` block.
///
/// The block must be a non-generic inherent impl containing an associated
/// `new(...)` constructor that returns `Self` or `Result<Self, E>`. That
/// constructor becomes the type's factory, using the outer attribute's
/// `options`.
///
/// Every method carrying `#[provider]` / `#[provider(...)]` also registers its
/// own return type. Providers that take `&self` first resolve the owning type
/// from the container as `owner`.
///
/// The consumed `#[provider]` and `#[qualifier]` attributes are removed before
/// the impl is re-emitted.
///
/// # Errors
/// Rejects:
/// - trait impls and generic impls;
/// - direct calls to provider methods;
/// - a missing or malformed `new`;
/// - malformed provider attributes or signatures.
fn expand_impl(
    mut item_impl: ItemImpl,
    options: &ProviderOptions,
) -> syn::Result<proc_macro2::TokenStream> {
    if item_impl.trait_.is_some() {
        return Err(syn::Error::new_spanned(
            &item_impl,
            "#[singleton] requires an inherent impl, not a trait impl",
        ));
    }
    if !item_impl.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &item_impl.generics,
            "generic singleton impl blocks are not supported yet",
        ));
    }
    // Provider methods must be reached through the container, not called from sibling methods.
    reject_direct_provider_calls(&item_impl)?;
    let constructor_index = item_impl
        .items
        .iter()
        .position(|item| match item {
            ImplItem::Fn(function) => function.sig.ident == "new",
            _ => false,
        })
        .ok_or_else(|| {
            syn::Error::new_spanned(
                &item_impl.self_ty,
                "a #[singleton] impl must contain a new(...) -> Self constructor",
            )
        })?;
    let output_type = item_impl.self_ty.as_ref().clone();
    let constructor = match &mut item_impl.items[constructor_index] {
        ImplItem::Fn(function) => function,
        // `constructor_index` was located by matching `ImplItem::Fn` above.
        _ => unreachable!(),
    };
    // `true` when `new` returns `Result<Self, E>`.
    let fallible_constructor = validate_impl_constructor(constructor)?;
    // Also strips `#[qualifier]` attributes from the constructor's parameters.
    let (argument_names, dependency_types, qualifiers) = dependencies(&mut constructor.sig)?;
    // Last path segment of the impl type; used to build unique helper names.
    let type_ident = singleton_type_ident(&output_type)?;
    let constructor_name = constructor.sig.ident.clone();
    let is_async = constructor.sig.asyncness.is_some();
    let invocation = if is_async {
        quote! { <#output_type>::#constructor_name(#(#argument_names),*).await }
    } else {
        quote! { <#output_type>::#constructor_name(#(#argument_names),*) }
    };
    let invocation = blocking_invocation(invocation, options.blocking, is_async, &output_type)?;
    let invocation = factory_invocation(invocation, fallible_constructor, &output_type);
    let registration = registration_with_options(
        type_ident,
        &output_type,
        &argument_names,
        &dependency_types,
        &qualifiers,
        invocation,
        options,
    );
    // One extra registration per `#[provider]` method in the impl.
    let mut provider_registrations = Vec::new();
    for item in &mut item_impl.items {
        let ImplItem::Fn(method) = item else { continue };
        let Some(provider_attribute) = method
            .attrs
            .iter()
            .find(|attribute| is_provider_attribute(attribute))
        else {
            continue;
        };
        // `#[provider]` uses default options; `#[provider(...)]` parses its arguments.
        let provider_options = match &provider_attribute.meta {
            syn::Meta::Path(_) => ProviderOptions::default(),
            syn::Meta::List(list) => parse_provider_options2(list.tokens.clone())?,
            syn::Meta::NameValue(_) => {
                return Err(syn::Error::new_spanned(
                    provider_attribute,
                    "use #[provider(...)]",
                ));
            }
        };
        // The attribute is consumed here; it must not survive into the emitted impl.
        method
            .attrs
            .retain(|attribute| !is_provider_attribute(attribute));
        let (provider_type, fallible) = provided_output(&method.sig.output)?;
        let (has_receiver, provider_arguments, provider_dependencies, provider_qualifiers) =
            provider_dependencies(&mut method.sig)?;
        let method_name = method.sig.ident.clone();
        let provider_registration_name = format_ident!("provider_{type_ident}_{method_name}");
        // `&self` providers are called on `owner`, which the prelude below resolves.
        let provider_invocation = if has_receiver {
            quote! { owner.#method_name(#(#provider_arguments),*) }
        } else {
            quote! { <#output_type>::#method_name(#(#provider_arguments),*) }
        };
        let is_async = method.sig.asyncness.is_some();
        let provider_invocation = if is_async {
            quote! { #provider_invocation.await }
        } else {
            provider_invocation
        };
        let provider_invocation = blocking_invocation(
            provider_invocation,
            provider_options.blocking,
            is_async,
            &provider_type,
        )?;
        let provider_invocation = factory_invocation(provider_invocation, fallible, &provider_type);
        let prelude = has_receiver.then(|| {
            quote! {
                let owner: ::std::sync::Arc<#output_type> =
                    container.resolve_dependency::<#output_type>(&context).await?;
            }
        });
        provider_registrations.push(registration_with_prelude(
            &provider_registration_name,
            &provider_type,
            &provider_arguments,
            &provider_dependencies,
            &provider_qualifiers,
            prelude.unwrap_or_default(),
            provider_invocation,
            &provider_options,
        ));
    }
    Ok(quote! {
        #item_impl
        #registration
        #(#provider_registrations)*
    })
}
fn reject_direct_provider_calls(item_impl: &ItemImpl) -> syn::Result<()> {
let provider_names = item_impl
.items
.iter()
.filter_map(|item| match item {
ImplItem::Fn(method) if method.attrs.iter().any(is_provider_attribute) => {
Some(method.sig.ident.clone())
}
_ => None,
})
.collect::<std::collections::HashSet<_>>();
if provider_names.is_empty() {
return Ok(());
}
for item in &item_impl.items {
if let ImplItem::Fn(method) = item {
let mut visitor = DirectProviderCallVisitor {
provider_names: &provider_names,
error: None,
};
visitor.visit_block(&method.block);
if let Some(error) = visitor.error {
return Err(error);
}
}
}
Ok(())
}
/// Syntax-tree visitor that detects direct calls to an impl's `#[provider]` methods.
///
/// It flags `self.provider(...)` and multi-segment path calls such as
/// `Self::provider(...)` whose last segment names a provider.
struct DirectProviderCallVisitor<'a> {
    /// Names of the impl's methods that carry a `#[provider]` attribute.
    provider_names: &'a std::collections::HashSet<syn::Ident>,
    /// Diagnostic recorded for a detected direct call; `None` if none was found.
    error: Option<syn::Error>,
}
impl<'ast> Visit<'ast> for DirectProviderCallVisitor<'_> {
fn visit_expr_method_call(&mut self, expression: &'ast syn::ExprMethodCall) {
if self.provider_names.contains(&expression.method)
&& matches!(expression.receiver.as_ref(), syn::Expr::Path(path) if path.path.is_ident("self"))
{
self.error = Some(syn::Error::new_spanned(
expression,
"direct provider calls bypass dependency injection; inject the provided type as a parameter",
));
return;
}
syn::visit::visit_expr_method_call(self, expression);
}
fn visit_expr_call(&mut self, expression: &'ast syn::ExprCall) {
if let syn::Expr::Path(path) = expression.func.as_ref()
&& path.path.segments.len() > 1
&& path
.path
.segments
.last()
.is_some_and(|segment| self.provider_names.contains(&segment.ident))
{
self.error = Some(syn::Error::new_spanned(
expression,
"direct provider calls bypass dependency injection; inject the provided type as a parameter",
));
return;
}
syn::visit::visit_expr_call(self, expression);
}
}
fn expand_configuration(mut item_impl: ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
if item_impl.trait_.is_some() || !item_impl.generics.params.is_empty() {
return Err(syn::Error::new_spanned(
&item_impl,
"#[configuration] requires a non-generic inherent impl",
));
}
reject_direct_provider_calls(&item_impl)?;
let configuration_type = item_impl.self_ty.as_ref().clone();
let configuration_ident = singleton_type_ident(&configuration_type)?.clone();
let configuration_registration = registration(
&format_ident!("configuration_{configuration_ident}"),
&configuration_type,
&[],
&[],
&[],
quote! { <#configuration_type as ::std::default::Default>::default() },
);
let mut registrations = Vec::new();
for item in &mut item_impl.items {
let ImplItem::Fn(method) = item else { continue };
let Some(provider_attribute) = method
.attrs
.iter()
.find(|attribute| is_provider_attribute(attribute))
else {
continue;
};
let provider_options = match &provider_attribute.meta {
syn::Meta::Path(_) => ProviderOptions::default(),
syn::Meta::List(list) => parse_provider_options2(list.tokens.clone())?,
syn::Meta::NameValue(_) => {
return Err(syn::Error::new_spanned(
provider_attribute,
"use #[provider(...)]",
));
}
};
method
.attrs
.retain(|attribute| !is_provider_attribute(attribute));
let (output_type, fallible) = provided_output(&method.sig.output)?;
let (has_receiver, argument_names, dependency_types, qualifiers) =
provider_dependencies(&mut method.sig)?;
let method_name = method.sig.ident.clone();
let registration_name = format_ident!("provider_{configuration_ident}_{method_name}");
let invocation = if has_receiver {
quote! { configuration.#method_name(#(#argument_names),*) }
} else {
quote! { <#configuration_type>::#method_name(#(#argument_names),*) }
};
let is_async = method.sig.asyncness.is_some();
let invocation = if is_async {
quote! { #invocation.await }
} else {
invocation
};
let invocation = blocking_invocation(
invocation,
provider_options.blocking,
is_async,
&output_type,
)?;
let invocation = factory_invocation(invocation, fallible, &output_type);
let prelude = has_receiver.then(|| {
quote! {
let configuration: ::std::sync::Arc<#configuration_type> =
container.resolve_dependency::<#configuration_type>(&context).await?;
}
});
registrations.push(registration_with_prelude(
®istration_name,
&output_type,
&argument_names,
&dependency_types,
&qualifiers,
prelude.unwrap_or_default(),
invocation,
&provider_options,
));
}
if registrations.is_empty() {
return Err(syn::Error::new_spanned(
&item_impl,
"#[configuration] must contain at least one #[provider] method",
));
}
Ok(quote! {
#item_impl
#configuration_registration
#(#registrations)*
})
}
fn validate_impl_constructor(constructor: &ImplItemFn) -> syn::Result<bool> {
if constructor.sig.receiver().is_some() {
return Err(syn::Error::new_spanned(
&constructor.sig,
"the singleton new constructor must be an associated function",
));
}
match &constructor.sig.output {
ReturnType::Type(_, ty) if is_self_type(ty) => Ok(false),
ReturnType::Type(_, ty) if result_ok_type(ty).is_some_and(is_self_type) => Ok(true),
_ => Err(syn::Error::new_spanned(
&constructor.sig.output,
"the singleton new constructor must return Self or Result<Self, E>",
)),
}
}
fn provided_output(output: &ReturnType) -> syn::Result<(Type, bool)> {
let ReturnType::Type(_, ty) = output else {
return Err(syn::Error::new_spanned(
output,
"a provider must return a value",
));
};
if let Some(inner) = result_ok_type(ty) {
Ok((inner.clone(), true))
} else {
Ok((ty.as_ref().clone(), false))
}
}
/// Returns `T` when `ty` is a path type whose last segment is `Result<T, ...>`.
///
/// Only the segment name is checked, so aliases such as `io::Result<T>` also
/// match. Returns `None` when the first generic argument is not a type.
fn result_ok_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else { return None };
    let last = type_path
        .path
        .segments
        .last()
        .filter(|segment| segment.ident == "Result")?;
    match &last.arguments {
        PathArguments::AngleBracketed(generic) => {
            generic.args.first().and_then(|argument| match argument {
                GenericArgument::Type(ok) => Some(ok),
                _ => None,
            })
        }
        _ => None,
    }
}
/// Returns `true` when `ty` is exactly the single-segment path `Self`.
fn is_self_type(ty: &Type) -> bool {
    if let Type::Path(type_path) = ty {
        type_path.path.is_ident("Self")
    } else {
        false
    }
}
/// Adapts a factory call expression to the factory's error handling.
///
/// Infallible calls are returned unchanged. Fallible calls, which yield
/// `Result<T, E>`, are passed through `__private::factory_result` together with
/// the provided type's name, and the result is propagated with `?`.
fn factory_invocation(
    invocation: proc_macro2::TokenStream,
    fallible: bool,
    output_type: &Type,
) -> proc_macro2::TokenStream {
    if !fallible {
        return invocation;
    }
    let auto_di = auto_di_path();
    let type_name = quote!(::std::any::type_name::<#output_type>());
    quote! {
        #auto_di::__private::factory_result(#invocation, #type_name,)?
    }
}
fn blocking_invocation(
invocation: proc_macro2::TokenStream,
blocking: bool,
is_async: bool,
output_type: &Type,
) -> syn::Result<proc_macro2::TokenStream> {
if !blocking {
return Ok(invocation);
}
if is_async {
return Err(syn::Error::new_spanned(
output_type,
"blocking can only be used with synchronous providers",
));
}
let auto_di = auto_di_path();
Ok(quote! {
#auto_di::__private::tokio::task::spawn_blocking(move || #invocation)
.await
.map_err(|error| #auto_di::DiError::Factory {
provider: ::std::any::type_name::<#output_type>(),
message: error.to_string(),
})?
})
}
/// Extracts the injected dependencies of a receiver-less function signature.
///
/// For each parameter it returns, in order: the binding name, the declared
/// type, and the optional `#[qualifier("...")]` value. Qualifier attributes are
/// removed from the signature.
///
/// # Errors
/// Rejects:
/// - a `self` receiver;
/// - non-identifier patterns;
/// - unsupported dependency types;
/// - malformed or duplicate qualifiers.
fn dependencies(
    signature: &mut Signature,
) -> syn::Result<(Vec<syn::Ident>, Vec<Type>, Vec<Option<String>>)> {
    let capacity = signature.inputs.len();
    let mut names = Vec::with_capacity(capacity);
    let mut types = Vec::with_capacity(capacity);
    let mut found_qualifiers = Vec::with_capacity(capacity);
    for input in signature.inputs.iter_mut() {
        let FnArg::Typed(typed) = input else {
            return Err(syn::Error::new_spanned(
                input,
                "singleton constructors cannot have a self receiver",
            ));
        };
        let Pat::Ident(binding) = typed.pat.as_ref() else {
            return Err(syn::Error::new_spanned(
                &typed.pat,
                "use a simple parameter name for an injected dependency",
            ));
        };
        let name = binding.ident.clone();
        validate_dependency_type(&typed.ty)?;
        let qualifier = take_qualifier(&mut typed.attrs)?;
        names.push(name);
        types.push((*typed.ty).clone());
        found_qualifiers.push(qualifier);
    }
    Ok((names, types, found_qualifiers))
}
/// Extracts the dependencies of a `#[provider]` method signature.
///
/// Returns:
/// - whether the method takes `&self`;
/// - the names, types and optional qualifiers of the remaining parameters, in order.
///
/// Qualifier attributes are removed from the signature.
///
/// # Errors
/// Rejects any receiver other than a single immutable `&self`, non-identifier
/// patterns, unsupported dependency types, and malformed or duplicate
/// qualifiers.
fn provider_dependencies(
    signature: &mut Signature,
) -> syn::Result<(bool, Vec<syn::Ident>, Vec<Type>, Vec<Option<String>>)> {
    let mut has_receiver = false;
    let (mut names, mut types, mut qualifiers) = (Vec::new(), Vec::new(), Vec::new());
    for input in signature.inputs.iter_mut() {
        let typed = match input {
            FnArg::Receiver(receiver) => {
                let shared_borrow =
                    receiver.reference.is_some() && receiver.mutability.is_none();
                if has_receiver || !shared_borrow {
                    return Err(syn::Error::new_spanned(
                        receiver,
                        "a #[provider] method only supports an immutable &self receiver",
                    ));
                }
                has_receiver = true;
                continue;
            }
            FnArg::Typed(typed) => typed,
        };
        let Pat::Ident(binding) = typed.pat.as_ref() else {
            return Err(syn::Error::new_spanned(
                &typed.pat,
                "use a simple parameter name for an injected dependency",
            ));
        };
        let name = binding.ident.clone();
        validate_dependency_type(&typed.ty)?;
        let qualifier = take_qualifier(&mut typed.attrs)?;
        names.push(name);
        types.push((*typed.ty).clone());
        qualifiers.push(qualifier);
    }
    Ok((has_receiver, names, types, qualifiers))
}
/// Returns `true` for `#[provider]` / `#[provider(...)]` written with the bare path `provider`.
///
/// Qualified paths such as `#[auto_di::provider]` do not match.
fn is_provider_attribute(attribute: &syn::Attribute) -> bool {
    matches!(attribute.path().get_ident(), Some(ident) if ident == "provider")
}
/// Removes every `#[qualifier("...")]` attribute and returns its string value.
///
/// Returns `Ok(None)` when no qualifier is present.
///
/// # Errors
/// The first qualifier's argument must be a string literal. A second qualifier
/// attribute is rejected.
fn take_qualifier(attributes: &mut Vec<syn::Attribute>) -> syn::Result<Option<String>> {
    let is_qualifier = |attribute: &syn::Attribute| attribute.path().is_ident("qualifier");
    let mut matching = attributes
        .iter()
        .filter(|attribute| is_qualifier(*attribute));
    let value = match matching.next() {
        None => None,
        Some(first) => {
            let parsed = first.parse_args::<LitStr>()?.value();
            if let Some(extra) = matching.next() {
                return Err(syn::Error::new_spanned(
                    extra,
                    "only one qualifier is allowed",
                ));
            }
            Some(parsed)
        }
    };
    attributes.retain(|attribute| !is_qualifier(attribute));
    Ok(value)
}
/// Generates a provider registration with default `ProviderOptions`.
///
/// The provider is unnamed, not primary, in singleton scope, not eager, has no
/// profile or condition, and has no lifecycle hooks.
fn registration(
    name: &syn::Ident,
    output_type: &Type,
    argument_names: &[syn::Ident],
    dependency_types: &[Type],
    qualifiers: &[Option<String>],
    invocation: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let defaults = ProviderOptions::default();
    registration_with_options(
        name,
        output_type,
        argument_names,
        dependency_types,
        qualifiers,
        invocation,
        &defaults,
    )
}
/// Generates a provider registration that needs no setup before dependencies are resolved.
///
/// Equivalent to `registration_with_prelude` with an empty prelude.
fn registration_with_options(
    name: &syn::Ident,
    output_type: &Type,
    argument_names: &[syn::Ident],
    dependency_types: &[Type],
    qualifiers: &[Option<String>],
    invocation: proc_macro2::TokenStream,
    options: &ProviderOptions,
) -> proc_macro2::TokenStream {
    let no_prelude = proc_macro2::TokenStream::new();
    registration_with_prelude(
        name,
        output_type,
        argument_names,
        dependency_types,
        qualifiers,
        no_prelude,
        invocation,
        options,
    )
}
/// Generates the hidden helper functions and the `inventory` submission that register one provider.
///
/// All names are derived from `name`. The following items are emitted:
/// - `__di_type_id_{name}` / `__di_type_name_{name}`: the `TypeId` and type name
///   of `output_type`.
/// - `__di_factory_{name}`: an async factory that:
///   1. runs `prelude`;
///   2. binds each `argument_names[i]: dependency_types[i]` (resolved by
///      `qualifiers[i]` when set);
///   3. evaluates `invocation` into `value: output_type`;
///   4. awaits the optional `post_construct` method;
///   5. returns the value as a `DynArc`.
/// - `__di_destroy_{name}`, only when `options.pre_destroy` is set: downcasts the
///   stored value and awaits that method.
///
/// These helpers and the remaining `options` are passed to
/// `ProviderDescriptor::configured`.
///
/// Note: `format_ident!` panics if `post_construct` / `pre_destroy` is not a
/// valid identifier.
fn registration_with_prelude(
    name: &syn::Ident,
    output_type: &Type,
    argument_names: &[syn::Ident],
    dependency_types: &[Type],
    qualifiers: &[Option<String>],
    prelude: proc_macro2::TokenStream,
    invocation: proc_macro2::TokenStream,
    options: &ProviderOptions,
) -> proc_macro2::TokenStream {
    let auto_di = auto_di_path();
    let factory_name = format_ident!("__di_factory_{name}");
    let type_id_name = format_ident!("__di_type_id_{name}");
    let type_name_name = format_ident!("__di_type_name_{name}");
    let destroy_name = format_ident!("__di_destroy_{name}");
    // Renders an optional string as `Some("...")` / `None` tokens.
    let option_tokens = |value: Option<&str>| match value {
        Some(value) => quote!(Some(#value)),
        None => quote!(None),
    };
    let provider_name = option_tokens(options.name.as_deref());
    let primary = options.primary;
    let eager = options.eager;
    let profile = option_tokens(options.profile.as_deref());
    let condition_key = option_tokens(options.condition_key.as_deref());
    let condition_value = option_tokens(options.condition_value.as_deref());
    // An absent (or, defensively, unrecognised) scope falls back to `Scope::Singleton`.
    let scope = match options.scope.as_deref().unwrap_or("singleton") {
        "prototype" => quote!(#auto_di::Scope::Prototype),
        "request" => quote!(#auto_di::Scope::Request),
        _ => quote!(#auto_di::Scope::Singleton),
    };
    // Awaited on the owned value before it is wrapped in an `Arc`.
    let post_construct = options.post_construct.as_ref().map(|method| {
        let method = format_ident!("{method}");
        quote! {
            #auto_di::__private::lifecycle_result(
                value.#method().await,
                ::std::any::type_name::<#output_type>(),
            )?;
        }
    });
    // The destroy hook receives the type-erased `DynArc` and downcasts it back to `output_type`.
    let (destroy_function, destroy_value) = if let Some(method) = &options.pre_destroy {
        let method = format_ident!("{method}");
        (
            quote! {
                #[doc(hidden)]
                fn #destroy_name(value: #auto_di::DynArc) -> #auto_di::BoxFuture<'static, ::std::result::Result<(), #auto_di::DiError>> {
                    ::std::boxed::Box::pin(async move {
                        let value = value.downcast::<#output_type>()
                            .map_err(|_| #auto_di::DiError::TypeMismatch(::std::any::type_name::<#output_type>()))?;
                        #auto_di::__private::lifecycle_result(
                            value.#method().await,
                            ::std::any::type_name::<#output_type>(),
                        )?;
                        Ok(())
                    })
                }
            },
            quote!(Some(#destroy_name as _)),
        )
    } else {
        (quote! {}, quote!(None))
    };
    // One `let <name>: <type> = ...;` statement per dependency, in parameter order.
    let resolve_dependencies = argument_names
        .iter()
        .zip(dependency_types.iter())
        .zip(qualifiers.iter())
        .map(|((name, ty), qualifier)| dependency_resolution(name, ty, qualifier.as_deref()));
    quote! {
        #[doc(hidden)]
        fn #type_id_name() -> ::std::any::TypeId {
            ::std::any::TypeId::of::<#output_type>()
        }
        #[doc(hidden)]
        fn #type_name_name() -> &'static str {
            ::std::any::type_name::<#output_type>()
        }
        #[doc(hidden)]
        fn #factory_name<'a>(
            container: &'a #auto_di::Container,
            context: #auto_di::ResolutionContext,
        ) -> #auto_di::BoxFuture<'a, ::std::result::Result<#auto_di::DynArc, #auto_di::DiError>> {
            ::std::boxed::Box::pin(async move {
                #prelude
                #(#resolve_dependencies)*
                let value: #output_type = #invocation;
                #post_construct
                Ok(::std::sync::Arc::new(value) as #auto_di::DynArc)
            })
        }
        #destroy_function
        #auto_di::__private::inventory::submit! {
            #auto_di::ProviderDescriptor::configured(
                #type_id_name,
                #type_name_name,
                #factory_name,
                #provider_name,
                #primary,
                #scope,
                #eager,
                #profile,
                #condition_key,
                #condition_value,
                #destroy_value,
            )
        }
    }
}
fn singleton_type_ident(ty: &Type) -> syn::Result<&syn::Ident> {
let Type::Path(path) = ty else {
return Err(syn::Error::new_spanned(
ty,
"singleton impl type must be a named type",
));
};
path.path
.segments
.last()
.map(|segment| &segment.ident)
.ok_or_else(|| syn::Error::new_spanned(ty, "singleton impl type must be a named type"))
}
/// Checks that an injected parameter uses one of the supported wrapper shapes.
///
/// Accepted shapes: `Arc<T>`, `Provider<T>`, `Lazy<T>`, `Option<Arc<T>>` and
/// `Vec<Arc<T>>`. Only the last path segment of each wrapper is inspected.
///
/// # Errors
/// Any other type is rejected with an error spanned on the type.
fn validate_dependency_type(ty: &Type) -> syn::Result<()> {
    let is_direct = ["Arc", "Provider", "Lazy"]
        .into_iter()
        .any(|wrapper| generic_inner(ty, wrapper).is_some());
    let is_collection_of_arc = ["Option", "Vec"]
        .into_iter()
        .filter_map(|wrapper| generic_inner(ty, wrapper))
        .any(|element| generic_inner(element, "Arc").is_some());
    if is_direct || is_collection_of_arc {
        Ok(())
    } else {
        Err(syn::Error::new_spanned(
            ty,
            "dependency must be Arc<T>, Option<Arc<T>>, Vec<Arc<T>>, Provider<T>, or Lazy<T>",
        ))
    }
}
/// Returns `T` when `ty` is a path type whose last segment is `expected<T>`.
///
/// The segment must have exactly one generic argument, and that argument must
/// be a type.
fn generic_inner<'a>(ty: &'a Type, expected: &str) -> Option<&'a Type> {
    let Type::Path(type_path) = ty else { return None };
    let last = type_path
        .path
        .segments
        .last()
        .filter(|segment| segment.ident == expected)?;
    let PathArguments::AngleBracketed(generic) = &last.arguments else {
        return None;
    };
    if generic.args.len() != 1 {
        return None;
    }
    match generic.args.first()? {
        GenericArgument::Type(inner) => Some(inner),
        _ => None,
    }
}
fn dependency_resolution(
name: &syn::Ident,
ty: &Type,
qualifier: Option<&str>,
) -> proc_macro2::TokenStream {
let auto_di = auto_di_path();
if let Some(inner) = generic_inner(ty, "Arc") {
if matches!(inner, Type::TraitObject(_)) {
if let Some(qualifier) = qualifier {
return quote! {
let #name: #ty = (*container.resolve_named_dependency::<#ty>(#qualifier, &context).await?).clone();
};
}
return quote! {
let #name: #ty = (*container.resolve_dependency::<#ty>(&context).await?).clone();
};
}
if let Some(qualifier) = qualifier {
return quote! {
let #name: #ty = container.resolve_named_dependency::<#inner>(#qualifier, &context).await?;
};
}
return quote! {
let #name: #ty = container.resolve_dependency::<#inner>(&context).await?;
};
}
if let Some(wrapped) = generic_inner(ty, "Option") {
let inner = generic_inner(wrapped, "Arc").expect("validated Option<Arc<T>>");
if matches!(inner, Type::TraitObject(_)) {
return quote! {
let #name: #ty = container
.resolve_optional_dependency::<#wrapped>(&context)
.await?
.map(|value| (*value).clone());
};
}
return quote! {
let #name: #ty = container.resolve_optional_dependency::<#inner>(&context).await?;
};
}
if let Some(wrapped) = generic_inner(ty, "Vec") {
let inner = generic_inner(wrapped, "Arc").expect("validated Vec<Arc<T>>");
if matches!(inner, Type::TraitObject(_)) {
return quote! {
let #name: #ty = container
.resolve_all_dependency::<#wrapped>(&context)
.await?
.into_iter()
.map(|value| (*value).clone())
.collect();
};
}
return quote! {
let #name: #ty = container.resolve_all_dependency::<#inner>(&context).await?;
};
}
if generic_inner(ty, "Provider").is_some() {
return quote! {
let #name: #ty = #auto_di::Provider::from_context(container.clone(), context.clone());
};
}
quote! {
let #name: #ty = #auto_di::Lazy::from_context(container.clone(), context.clone());
}
}