use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
Attribute, Expr, Fields, FnArg, GenericArgument, ImplItem, ImplItemFn, Item, ItemFn, ItemImpl,
ItemStruct, LitStr, Pat, PathArguments, ReturnType, Signature, Type, parse_macro_input,
visit::Visit,
};
fn auto_di_path() -> proc_macro2::TokenStream {
match proc_macro_crate::crate_name("auto-di") {
Ok(proc_macro_crate::FoundCrate::Itself) => quote!(::auto_di),
Ok(proc_macro_crate::FoundCrate::Name(name)) => {
let name = format_ident!("{name}");
quote!(::#name)
}
Err(_) => quote!(::auto_di),
}
}
/// Options parsed from `#[singleton(...)]` / `#[provider(...)]` attribute arguments.
///
/// Every field defaults to "not given" (`false` / `None`); see `parse_provider_options2` for
/// the accepted syntax.
#[derive(Default, Clone)]
struct ProviderOptions {
    /// `primary` flag.
    primary: bool,
    /// `eager` flag.
    eager: bool,
    /// `blocking` flag: a synchronous factory is run through `spawn_blocking`.
    blocking: bool,
    /// `name = "..."`.
    name: Option<String>,
    /// `scope = "..."`; when absent the provider is registered as a singleton.
    scope: Option<String>,
    /// `profile = "..."`.
    profile: Option<String>,
    /// Key part of `condition = "key"` or `condition = "key=value"`.
    condition_key: Option<String>,
    /// Value part of `condition = "key=value"`, present only when an `=` was given.
    condition_value: Option<String>,
    /// Name of a method awaited on the freshly built value before it is stored.
    post_construct: Option<String>,
    /// Name of a method registered as the provider's destroy hook.
    pre_destroy: Option<String>,
}
pub(crate) fn singleton(attribute: TokenStream, item: TokenStream) -> TokenStream {
let options = match parse_provider_options(attribute) {
Ok(options) => options,
Err(error) => return error.to_compile_error().into(),
};
let item = parse_macro_input!(item as Item);
let expanded = match item {
Item::Fn(function) => expand_function(function, &options),
Item::Impl(item_impl) => expand_impl(item_impl, &options),
other => Err(syn::Error::new_spanned(
other,
"#[singleton] can only be used on a function or an inherent impl block",
)),
};
match expanded {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
fn parse_provider_options(attribute: TokenStream) -> syn::Result<ProviderOptions> {
parse_provider_options2(attribute.into())
}
/// Parses provider options into a [`ProviderOptions`].
///
/// The input is the argument list of `#[singleton(...)]`, of `#[provider(...)]`, or of a
/// method-level `#[provider(...)]`. Entries are comma separated:
/// - flags: `primary`, `eager`, `blocking`;
/// - string options: `name = "..."`, `scope = "singleton" | "prototype" | "request"`,
///   `profile = "..."`, `condition = "key"` or `condition = "key=value"`,
///   `post_construct = "method"`, `pre_destroy = "method"`.
///
/// # Errors
/// Returns a spanned error for any of:
/// - an unknown option;
/// - a missing or non-string value;
/// - an invalid `scope`;
/// - a lifecycle hook that is not a valid Rust identifier.
fn parse_provider_options2(attribute: proc_macro2::TokenStream) -> syn::Result<ProviderOptions> {
    use syn::parse::Parser;
    // Options that require a `= "string"` value.
    const KEYED_OPTIONS: [&str; 6] = [
        "name",
        "scope",
        "profile",
        "condition",
        "post_construct",
        "pre_destroy",
    ];
    let mut options = ProviderOptions::default();
    let parser = syn::meta::parser(|meta| {
        if meta.path.is_ident("primary") {
            options.primary = true;
            return Ok(());
        }
        if meta.path.is_ident("eager") {
            options.eager = true;
            return Ok(());
        }
        if meta.path.is_ident("blocking") {
            options.blocking = true;
            return Ok(());
        }
        // Check the key before demanding `= "..."`. A typo such as `#[singleton(primay)]`
        // then reports "unknown provider option" instead of "expected `=`".
        if !KEYED_OPTIONS.iter().any(|key| meta.path.is_ident(key)) {
            return Err(meta.error("unknown provider option"));
        }
        let literal = meta.value()?.parse::<LitStr>()?;
        let value = literal.value();
        if meta.path.is_ident("name") {
            options.name = Some(value);
        } else if meta.path.is_ident("scope") {
            // Validate here so the error points at the offending literal.
            if !matches!(value.as_str(), "singleton" | "prototype" | "request") {
                return Err(syn::Error::new(
                    literal.span(),
                    "scope must be singleton, prototype, or request",
                ));
            }
            options.scope = Some(value);
        } else if meta.path.is_ident("profile") {
            options.profile = Some(value);
        } else if meta.path.is_ident("condition") {
            let (key, expected) = value
                .split_once('=')
                .map_or((value.as_str(), None), |(k, v)| (k, Some(v)));
            options.condition_key = Some(key.to_owned());
            options.condition_value = expected.map(str::to_owned);
        } else {
            // Only `post_construct` / `pre_destroy` remain. Their values are later turned
            // into method identifiers with `format_ident!`, which panics on strings that
            // are not identifiers, so reject such values with a proper error here.
            if syn::parse_str::<syn::Ident>(&value).is_err() {
                return Err(syn::Error::new(
                    literal.span(),
                    "lifecycle hook must name a method, e.g. \"init\"",
                ));
            }
            if meta.path.is_ident("post_construct") {
                options.post_construct = Some(value);
            } else {
                options.pre_destroy = Some(value);
            }
        }
        Ok(())
    });
    parser.parse2(attribute)?;
    Ok(options)
}
/// `#[qualifier(...)]` as a standalone attribute macro: returns the annotated item unchanged.
///
/// The attribute's arguments are ignored here. Qualifier values on parameters are read and
/// stripped by the enclosing expansions through `take_qualifier`.
/// NOTE(review): presumably this passthrough exists so the attribute name resolves wherever
/// no enclosing macro consumes it; confirm the intended use.
pub(crate) fn qualifier(_attribute: TokenStream, item: TokenStream) -> TokenStream {
    item
}
pub(crate) fn application(_attribute: TokenStream, item: TokenStream) -> TokenStream {
let auto_di = auto_di_path();
let mut function = parse_macro_input!(item as ItemFn);
if function.sig.asyncness.is_none() {
return syn::Error::new_spanned(&function.sig, "#[application] requires async fn")
.to_compile_error()
.into();
}
let body = function.block;
function.sig.asyncness = None;
function.block = Box::new(syn::parse_quote!({
let __runtime = #auto_di::__private::tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("auto-di failed to create the Tokio runtime");
__runtime.block_on(async move {
let __container = #auto_di::global_container()?;
__container.validate().await?;
let __result = (async move #body).await;
__container.shutdown().await?;
__result
})
}));
quote!(#function).into()
}
pub(crate) fn configuration_properties(attribute: TokenStream, item: TokenStream) -> TokenStream {
let prefix = parse_macro_input!(attribute as LitStr).value();
let structure = parse_macro_input!(item as ItemStruct);
match expand_configuration_properties(prefix, structure) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
/// Generates a `ConfigurationProperties` implementation for a struct with named fields.
///
/// The struct is also registered as a provider built from `from_environment()`.
///
/// Each field is read from an environment variable named `PREFIX_FIELD`. The name is
/// upper-cased, and `.` and `-` are replaced with `_`. For example, prefix `"app.server"` and
/// field `port` read `APP_SERVER_PORT`; an empty prefix reads just `PORT`. Raw identifiers use
/// their unprefixed name, so `r#type` reads `..._TYPE`.
///
/// Values are converted with `str::parse::<FieldType>()`. A missing variable or a failed
/// parse becomes `DiError::Configuration` at runtime.
///
/// # Errors
/// Rejects generic structs and structs without named fields.
fn expand_configuration_properties(
    prefix: String,
    structure: ItemStruct,
) -> syn::Result<proc_macro2::TokenStream> {
    use syn::ext::IdentExt;
    // The generated `impl ... for #ident` carries no generics, so generic structs cannot work.
    if !structure.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &structure.generics,
            "generic configuration properties structs are not supported",
        ));
    }
    let auto_di = auto_di_path();
    let ident = &structure.ident;
    let syn::Fields::Named(named_fields) = &structure.fields else {
        return Err(syn::Error::new_spanned(
            &structure,
            "configuration properties require named fields",
        ));
    };
    let bindings = named_fields.named.iter().map(|field| {
        let name = field.ident.as_ref().expect("named field");
        let ty = &field.ty;
        // `unraw` drops the `r#` prefix, which is not valid in an environment variable name.
        let field_name = name.unraw().to_string();
        // Without this, an empty prefix would yield a key with a leading underscore.
        let key = if prefix.is_empty() {
            field_name
        } else {
            format!("{prefix}_{field_name}")
        };
        let key = key.replace(['.', '-'], "_").to_uppercase();
        quote! {
            #name: ::std::env::var(#key)
                .map_err(|error| #auto_di::DiError::Configuration { key: #key.into(), message: error.to_string() })?
                .parse::<#ty>()
                .map_err(|_| #auto_di::DiError::Configuration { key: #key.into(), message: "value could not be parsed".into() })?
        }
    });
    let provider_name = format_ident!("configuration_properties_{ident}");
    let invocation = quote! { <#ident as #auto_di::ConfigurationProperties>::from_environment()? };
    let registration = registration(
        &provider_name,
        &syn::parse_quote!(#ident),
        &[],
        &[],
        &[],
        invocation,
    );
    Ok(quote! {
        #structure
        impl #auto_di::ConfigurationProperties for #ident {
            fn from_environment() -> ::std::result::Result<Self, #auto_di::DiError> {
                Ok(Self { #(#bindings),* })
            }
        }
        #registration
    })
}
pub(crate) fn configuration(_attribute: TokenStream, item: TokenStream) -> TokenStream {
let item_impl = parse_macro_input!(item as ItemImpl);
match expand_configuration(item_impl) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
/// Entry point for item-level `#[provider(...)]`; behaves exactly like `#[singleton(...)]`.
///
/// Method-level `#[provider]` attributes inside `#[singleton]` / `#[configuration]` impl
/// blocks are consumed by those expansions instead.
pub(crate) fn provider(attribute: TokenStream, item: TokenStream) -> TokenStream {
    singleton(attribute, item)
}
pub(crate) fn injectable(_attribute: TokenStream, item: TokenStream) -> TokenStream {
let structure = parse_macro_input!(item as ItemStruct);
match expand_injectable(structure) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
pub(crate) fn injected(_attribute: TokenStream, item: TokenStream) -> TokenStream {
let function = parse_macro_input!(item as ItemFn);
match expand_injected(function) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
/// How an `#[inject]` attribute asked for a value to be supplied.
enum InjectOverride {
    /// `#[inject(expr)]`: use the given expression. A closure receives one resolved dependency.
    Expression(Expr),
    /// Bare `#[inject]`: resolve the value from the declared type.
    Automatic,
}
fn take_inject(attributes: &mut Vec<Attribute>) -> syn::Result<Option<InjectOverride>> {
let Some(position) = attributes
.iter()
.position(|attribute| attribute.path().is_ident("inject"))
else {
return Ok(None);
};
let attribute = attributes.remove(position);
match attribute.meta {
syn::Meta::Path(_) => Ok(Some(InjectOverride::Automatic)),
syn::Meta::List(list) => Ok(Some(InjectOverride::Expression(syn::parse2(list.tokens)?))),
syn::Meta::NameValue(_) => Err(syn::Error::new_spanned(
attribute,
"use #[inject] or #[inject(expression)]",
)),
}
}
/// Removes the first `#[argument]` attribute from `attributes`.
///
/// Returns `true` if one was found, which forces the parameter to stay a plain caller-supplied
/// argument.
fn take_argument(attributes: &mut Vec<Attribute>) -> bool {
    match attributes
        .iter()
        .position(|attribute| attribute.path().is_ident("argument"))
    {
        Some(index) => {
            attributes.remove(index);
            true
        }
        None => false,
    }
}
/// Expands `#[injectable]` on a struct.
///
/// The output is the struct itself plus a `#[singleton]`-annotated inherent impl containing a
/// generated `new(...)` constructor.
///
/// Each field normally becomes one constructor parameter named `__auto_di_field_{index}`
/// (see `injectable_field`). A field with a non-closure `#[inject(expr)]` instead gets its
/// expression evaluated inline, with no parameter. `#[inject]` attributes are removed from the
/// emitted struct.
///
/// # Errors
/// Rejects generic structs and propagates errors from malformed `#[inject]` attributes.
fn expand_injectable(mut structure: ItemStruct) -> syn::Result<proc_macro2::TokenStream> {
    if !structure.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &structure.generics,
            "generic injectable structs are not supported",
        ));
    }
    let auto_di = auto_di_path();
    let ident = structure.ident.clone();
    let mut parameters = Vec::new();
    let mut values = Vec::new();
    for (index, field) in structure.fields.iter_mut().enumerate() {
        let parameter = format_ident!("__auto_di_field_{index}");
        // Consumes the field's `#[inject]` so it is not re-emitted on the struct.
        let override_value = take_inject(&mut field.attrs)?;
        let (parameter_tokens, value) = injectable_field(&parameter, &field.ty, override_value)?;
        // Literal `#[inject(expr)]` fields contribute no constructor parameter.
        parameters.extend(parameter_tokens);
        values.push(value);
    }
    let construction = match &structure.fields {
        Fields::Named(fields) => {
            let names = fields.named.iter().map(|field| {
                field
                    .ident
                    .as_ref()
                    .expect("named fields always have an identifier")
            });
            quote!(Self { #(#names: #values),* })
        }
        Fields::Unnamed(_) => quote!(Self(#(#values),*)),
        Fields::Unit => quote!(Self),
    };
    Ok(quote! {
        #structure
        #[#auto_di::singleton]
        impl #ident {
            fn new(#(#parameters),*) -> Self {
                #construction
            }
        }
    })
}
/// Decides how one `#[injectable]` field is supplied to the generated constructor.
///
/// Returns `(parameter, value)`: an optional `name: Type` constructor parameter, and the
/// expression used to initialise the field.
/// - Closure `#[inject(|x: T| ...)]`: the parameter is `T` when `T` is an `Arc<..>`, and
///   `Arc<T>` otherwise. In the second case the closure receives a clone of the pointee.
/// - Other `#[inject(expr)]`: there is no parameter, and `expr` is used directly.
/// - `#[inject]` or no attribute: the field type itself becomes the parameter when it is an
///   accepted dependency type. Otherwise the parameter is `Arc<FieldType>` and the field
///   receives a clone of the pointee.
///
/// # Errors
/// Fails when an inject closure does not take exactly one explicitly typed argument.
fn injectable_field(
    parameter: &syn::Ident,
    field_type: &Type,
    override_value: Option<InjectOverride>,
) -> syn::Result<(Option<proc_macro2::TokenStream>, proc_macro2::TokenStream)> {
    let closure = match override_value {
        Some(InjectOverride::Expression(Expr::Closure(closure))) => closure,
        Some(InjectOverride::Expression(expression)) => return Ok((None, quote!(#expression))),
        Some(InjectOverride::Automatic) | None => {
            let injected_as_declared = validate_dependency_type(field_type).is_ok();
            return Ok(if injected_as_declared {
                (Some(quote!(#parameter: #field_type)), quote!(#parameter))
            } else {
                (
                    Some(quote!(#parameter: ::std::sync::Arc<#field_type>)),
                    quote!((*#parameter).clone()),
                )
            });
        }
    };
    if closure.inputs.len() != 1 {
        return Err(syn::Error::new_spanned(
            closure,
            "an inject closure must take exactly one typed dependency",
        ));
    }
    let Some(Pat::Type(typed_input)) = closure.inputs.first() else {
        return Err(syn::Error::new_spanned(
            &closure.inputs,
            "inject closure input must declare its type",
        ));
    };
    let dependency_type = typed_input.ty.as_ref();
    let wants_arc = generic_inner(dependency_type, "Arc").is_some();
    let field_parts = if wants_arc {
        (
            Some(quote!(#parameter: #dependency_type)),
            quote!((#closure)(#parameter)),
        )
    } else {
        (
            Some(quote!(#parameter: ::std::sync::Arc<#dependency_type>)),
            quote!((#closure)((*#parameter).clone())),
        )
    };
    Ok(field_parts)
}
/// Expands `#[injected]` on a function.
///
/// The original function is renamed to `__auto_di_injected_{name}` and made private. A new
/// `async` wrapper with the original name, visibility and attributes resolves the injected
/// parameters from the global container and then calls it.
///
/// A parameter is injected when it carries `#[inject]` / `#[inject(expr)]`, or when its type
/// passes `validate_dependency_type`, unless it is marked `#[argument]`. All other
/// parameters, and any receiver, remain parameters of the wrapper.
///
/// The wrapper returns `Result<OriginalOutput, DiError>` and propagates resolution failures
/// with `?`.
///
/// # Errors
/// Rejects:
/// - generic functions;
/// - non-identifier parameter patterns;
/// - parameters marked both `#[inject]` and `#[argument]`;
/// - `#[qualifier]` on a parameter that is not injected;
/// - functions with nothing to inject.
fn expand_injected(mut implementation: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
    if !implementation.sig.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &implementation.sig.generics,
            "generic injected functions are not supported",
        ));
    }
    let auto_di = auto_di_path();
    let original_name = implementation.sig.ident.clone();
    let implementation_name = format_ident!("__auto_di_injected_{original_name}");
    let visibility = implementation.vis.clone();
    let original_output = implementation.sig.output.clone();
    // The wrapper wraps this type in `Result<_, DiError>`; `fn f()` maps to `()`.
    let output_type: Type = match &original_output {
        ReturnType::Default => syn::parse_quote!(()),
        ReturnType::Type(_, ty) => ty.as_ref().clone(),
    };
    let was_async = implementation.sig.asyncness.is_some();
    let mut wrapper_inputs = syn::punctuated::Punctuated::new();
    // Every typed parameter is forwarded to the renamed function by name, in order.
    let mut call_arguments = Vec::new();
    // `let name: Type = <resolution>;` statements for injected parameters.
    let mut resolutions = Vec::new();
    let mut has_receiver = false;
    for input in &mut implementation.sig.inputs {
        match input {
            FnArg::Receiver(receiver) => {
                has_receiver = true;
                wrapper_inputs.push(FnArg::Receiver(receiver.clone()));
            }
            FnArg::Typed(argument) => {
                let Pat::Ident(pattern) = argument.pat.as_ref() else {
                    return Err(syn::Error::new_spanned(
                        &argument.pat,
                        "injected functions require simple parameter names",
                    ));
                };
                let name = pattern.ident.clone();
                let argument_type = argument.ty.as_ref().clone();
                // These calls strip the helper attributes from the renamed function's
                // parameters, so neither the implementation nor the wrapper re-emits them.
                let inject = take_inject(&mut argument.attrs)?;
                let force_argument = take_argument(&mut argument.attrs);
                let qualifier = take_qualifier(&mut argument.attrs)?;
                if force_argument && inject.is_some() {
                    return Err(syn::Error::new_spanned(
                        argument,
                        "a parameter cannot use both #[inject] and #[argument]",
                    ));
                }
                // Dependency-shaped types are injected even without `#[inject]`.
                let inferred_injection = validate_dependency_type(&argument_type).is_ok();
                if !force_argument && (inject.is_some() || inferred_injection) {
                    let inject = inject.unwrap_or(InjectOverride::Automatic);
                    let value = global_injected_value(argument.ty.as_ref(), inject, qualifier)?;
                    resolutions.push(quote!(let #name: #argument_type = #value;));
                } else {
                    if qualifier.is_some() {
                        return Err(syn::Error::new_spanned(
                            argument,
                            "qualifier requires #[inject] on an #[injected] parameter",
                        ));
                    }
                    wrapper_inputs.push(FnArg::Typed(argument.clone()));
                }
                call_arguments.push(name);
            }
        }
    }
    if resolutions.is_empty() {
        return Err(syn::Error::new_spanned(
            &implementation.sig,
            "#[injected] requires an injectable parameter or #[inject(expression)]",
        ));
    }
    implementation.sig.ident = implementation_name.clone();
    implementation.vis = syn::Visibility::Inherited;
    let call = if has_receiver {
        quote!(self.#implementation_name(#(#call_arguments),*))
    } else {
        quote!(#implementation_name(#(#call_arguments),*))
    };
    let call = if was_async { quote!(#call.await) } else { call };
    // The wrapper starts as a copy of the (already renamed) implementation, keeping its
    // attributes and other signature parts, and then has the following overridden.
    let mut wrapper = implementation.clone();
    wrapper.sig.ident = original_name;
    wrapper.sig.inputs = wrapper_inputs;
    // Always async, because resolution from the container is awaited.
    wrapper.sig.asyncness = Some(syn::token::Async::default());
    wrapper.sig.output =
        syn::parse_quote!(-> ::std::result::Result<#output_type, #auto_di::DiError>);
    wrapper.vis = visibility;
    wrapper.block = Box::new(syn::parse_quote!({
        #(#resolutions)*
        Ok(#call)
    }));
    Ok(quote! {
        #implementation
        #wrapper
    })
}
fn global_injected_value(
ty: &Type,
inject: InjectOverride,
qualifier: Option<String>,
) -> syn::Result<proc_macro2::TokenStream> {
match inject {
InjectOverride::Automatic => global_auto_value(ty, qualifier.as_deref()),
InjectOverride::Expression(Expr::Closure(closure)) => {
if closure.inputs.len() != 1 {
return Err(syn::Error::new_spanned(
closure,
"an inject closure must take exactly one typed dependency",
));
}
let Pat::Type(input) = closure.inputs.first().unwrap() else {
return Err(syn::Error::new_spanned(
&closure.inputs,
"inject closure input must declare its type",
));
};
let dependency = global_auto_value(input.ty.as_ref(), qualifier.as_deref())?;
Ok(quote!((#closure)(#dependency)))
}
InjectOverride::Expression(expression) => {
if qualifier.is_some() {
return Err(syn::Error::new_spanned(
expression,
"qualifier cannot be combined with a literal inject expression",
));
}
Ok(quote!(#expression))
}
}
}
fn global_auto_value(ty: &Type, qualifier: Option<&str>) -> syn::Result<proc_macro2::TokenStream> {
let auto_di = auto_di_path();
if let Some(inner) = generic_inner(ty, "Arc") {
if matches!(inner, Type::TraitObject(_)) {
let resolution = if let Some(qualifier) = qualifier {
quote!((*#auto_di::global_container()?.resolve_named::<#ty>(#qualifier).await?).clone())
} else {
quote!((*#auto_di::resolve::<#ty>().await?).clone())
};
return Ok(resolution);
}
let resolution = if let Some(qualifier) = qualifier {
quote!(#auto_di::global_container()?.resolve_named::<#inner>(#qualifier).await?)
} else {
quote!(#auto_di::resolve::<#inner>().await?)
};
return Ok(resolution);
}
if qualifier.is_some() {
return Err(syn::Error::new_spanned(
ty,
"qualifier is currently supported on Arc<T> injected parameters",
));
}
if let Some(wrapped) = generic_inner(ty, "Option") {
let inner = generic_inner(wrapped, "Arc").ok_or_else(|| {
syn::Error::new_spanned(ty, "optional injection must be Option<Arc<T>>")
})?;
return Ok(quote!(#auto_di::global_container()?.resolve_optional::<#inner>().await?));
}
if let Some(wrapped) = generic_inner(ty, "Vec") {
let inner = generic_inner(wrapped, "Arc").ok_or_else(|| {
syn::Error::new_spanned(ty, "collection injection must be Vec<Arc<T>>")
})?;
return Ok(quote!(#auto_di::global_container()?.resolve_all::<#inner>().await?));
}
if generic_inner(ty, "Provider").is_some() {
return Ok(quote!(#auto_di::Provider::from_context(
#auto_di::global_container()?.clone(),
#auto_di::ResolutionContext::default(),
)));
}
if generic_inner(ty, "Lazy").is_some() {
return Ok(quote!(#auto_di::Lazy::from_context(
#auto_di::global_container()?.clone(),
#auto_di::ResolutionContext::default(),
)));
}
Ok(quote!((*#auto_di::resolve::<#ty>().await?).clone()))
}
/// Expands `#[singleton]` on a free function.
///
/// The function is re-emitted unchanged (apart from stripped parameter `#[qualifier]`s). Its
/// return type, or the `Ok` type of a `Result`, is registered as a provider whose factory
/// calls it. Parameters are resolved as dependencies. An `async` function is awaited; with
/// `blocking`, a synchronous function runs through `spawn_blocking`.
fn expand_function(
    mut function: ItemFn,
    options: &ProviderOptions,
) -> syn::Result<proc_macro2::TokenStream> {
    let (output_type, fallible) = provided_output(&function.sig.output)?;
    let (argument_names, dependency_types, qualifiers) = dependencies(&mut function.sig)?;
    let provider_fn = &function.sig.ident;
    let is_async = function.sig.asyncness.is_some();
    let call = quote! { #provider_fn(#(#argument_names),*) };
    let call = if is_async { quote! { #call.await } } else { call };
    let call = blocking_invocation(call, options.blocking, is_async, &output_type)?;
    let call = factory_invocation(call, fallible, &output_type);
    let registration = registration_with_options(
        provider_fn,
        &output_type,
        &argument_names,
        &dependency_types,
        &qualifiers,
        call,
        options,
    );
    Ok(quote! {
        #function
        #registration
    })
}
/// Expands `#[singleton]` on an inherent, non-generic `impl` block.
///
/// The block's `new(...)` associated function becomes the factory for the impl's self type:
/// - its parameters are resolved as dependencies, optionally with `#[qualifier]`;
/// - it may return `Self` or `Result<Self, E>`.
///
/// Every method marked `#[provider]` / `#[provider(...)]` additionally registers its return
/// type as a separate provider. A `&self` provider method first resolves the owning
/// singleton from the container.
///
/// # Errors
/// Rejects:
/// - trait impls and generic impls;
/// - impls without `new`, and invalid constructor signatures;
/// - direct calls to provider methods inside the impl;
/// - malformed provider options or dependency parameters.
fn expand_impl(
    mut item_impl: ItemImpl,
    options: &ProviderOptions,
) -> syn::Result<proc_macro2::TokenStream> {
    if item_impl.trait_.is_some() {
        return Err(syn::Error::new_spanned(
            &item_impl,
            "#[singleton] requires an inherent impl, not a trait impl",
        ));
    }
    if !item_impl.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &item_impl.generics,
            "generic singleton impl blocks are not supported yet",
        ));
    }
    // Runs before the `#[provider]` attributes are stripped below, because provider methods
    // are identified by that attribute.
    reject_direct_provider_calls(&item_impl)?;
    // The constructor is located by name: it must be called `new`.
    let constructor_index = item_impl
        .items
        .iter()
        .position(|item| match item {
            ImplItem::Fn(function) => function.sig.ident == "new",
            _ => false,
        })
        .ok_or_else(|| {
            syn::Error::new_spanned(
                &item_impl.self_ty,
                "a #[singleton] impl must contain a new(...) -> Self constructor",
            )
        })?;
    let output_type = item_impl.self_ty.as_ref().clone();
    let constructor = match &mut item_impl.items[constructor_index] {
        ImplItem::Fn(function) => function,
        // `constructor_index` was found by matching `ImplItem::Fn` above.
        _ => unreachable!(),
    };
    let fallible_constructor = validate_impl_constructor(constructor)?;
    // Also strips `#[qualifier]` attributes from the constructor's parameters.
    let (argument_names, dependency_types, qualifiers) = dependencies(&mut constructor.sig)?;
    // Last path segment of the self type; used to name the generated items.
    let type_ident = singleton_type_ident(&output_type)?;
    let constructor_name = constructor.sig.ident.clone();
    let is_async = constructor.sig.asyncness.is_some();
    let invocation = if is_async {
        quote! { <#output_type>::#constructor_name(#(#argument_names),*).await }
    } else {
        quote! { <#output_type>::#constructor_name(#(#argument_names),*) }
    };
    let invocation = blocking_invocation(invocation, options.blocking, is_async, &output_type)?;
    let invocation = factory_invocation(invocation, fallible_constructor, &output_type);
    let registration = registration_with_options(
        type_ident,
        &output_type,
        &argument_names,
        &dependency_types,
        &qualifiers,
        invocation,
        options,
    );
    // Each `#[provider]` method registers its own return type as an extra provider.
    let mut provider_registrations = Vec::new();
    for item in &mut item_impl.items {
        let ImplItem::Fn(method) = item else { continue };
        let Some(provider_attribute) = method
            .attrs
            .iter()
            .find(|attribute| is_provider_attribute(attribute))
        else {
            continue;
        };
        // A bare `#[provider]` uses default options; `#[provider(...)]` parses them.
        let provider_options = match &provider_attribute.meta {
            syn::Meta::Path(_) => ProviderOptions::default(),
            syn::Meta::List(list) => parse_provider_options2(list.tokens.clone())?,
            syn::Meta::NameValue(_) => {
                return Err(syn::Error::new_spanned(
                    provider_attribute,
                    "use #[provider(...)]",
                ));
            }
        };
        // Strip the marker so it is not re-emitted on the method.
        method
            .attrs
            .retain(|attribute| !is_provider_attribute(attribute));
        let (provider_type, fallible) = provided_output(&method.sig.output)?;
        // The local binding below shadows the function name only after the call has run.
        let (has_receiver, provider_arguments, provider_dependencies, provider_qualifiers) =
            provider_dependencies(&mut method.sig)?;
        let method_name = method.sig.ident.clone();
        let provider_registration_name = format_ident!("provider_{type_ident}_{method_name}");
        // `&self` providers are called on `owner`, the singleton resolved in `prelude`.
        let provider_invocation = if has_receiver {
            quote! { owner.#method_name(#(#provider_arguments),*) }
        } else {
            quote! { <#output_type>::#method_name(#(#provider_arguments),*) }
        };
        let is_async = method.sig.asyncness.is_some();
        let provider_invocation = if is_async {
            quote! { #provider_invocation.await }
        } else {
            provider_invocation
        };
        let provider_invocation = blocking_invocation(
            provider_invocation,
            provider_options.blocking,
            is_async,
            &provider_type,
        )?;
        let provider_invocation = factory_invocation(provider_invocation, fallible, &provider_type);
        let prelude = has_receiver.then(|| {
            quote! {
                let owner: ::std::sync::Arc<#output_type> =
                    container.resolve_dependency::<#output_type>(&context).await?;
            }
        });
        provider_registrations.push(registration_with_prelude(
            &provider_registration_name,
            &provider_type,
            &provider_arguments,
            &provider_dependencies,
            &provider_qualifiers,
            prelude.unwrap_or_default(),
            provider_invocation,
            &provider_options,
        ));
    }
    Ok(quote! {
        #item_impl
        #registration
        #(#provider_registrations)*
    })
}
fn reject_direct_provider_calls(item_impl: &ItemImpl) -> syn::Result<()> {
let provider_names = item_impl
.items
.iter()
.filter_map(|item| match item {
ImplItem::Fn(method) if method.attrs.iter().any(is_provider_attribute) => {
Some(method.sig.ident.clone())
}
_ => None,
})
.collect::<std::collections::HashSet<_>>();
if provider_names.is_empty() {
return Ok(());
}
for item in &item_impl.items {
if let ImplItem::Fn(method) = item {
let mut visitor = DirectProviderCallVisitor {
provider_names: &provider_names,
error: None,
};
visitor.visit_block(&method.block);
if let Some(error) = visitor.error {
return Err(error);
}
}
}
Ok(())
}
/// Syntax-tree visitor used by `reject_direct_provider_calls` to find calls to an impl's
/// `#[provider]` methods.
struct DirectProviderCallVisitor<'a> {
    /// Error describing a direct provider call found so far; `None` when none was found.
    error: Option<syn::Error>,
    /// Names of the impl's `#[provider]` methods.
    provider_names: &'a std::collections::HashSet<syn::Ident>,
}
impl<'ast> Visit<'ast> for DirectProviderCallVisitor<'_> {
fn visit_expr_method_call(&mut self, expression: &'ast syn::ExprMethodCall) {
if self.provider_names.contains(&expression.method)
&& matches!(expression.receiver.as_ref(), syn::Expr::Path(path) if path.path.is_ident("self"))
{
self.error = Some(syn::Error::new_spanned(
expression,
"direct provider calls bypass dependency injection; inject the provided type as a parameter",
));
return;
}
syn::visit::visit_expr_method_call(self, expression);
}
fn visit_expr_call(&mut self, expression: &'ast syn::ExprCall) {
if let syn::Expr::Path(path) = expression.func.as_ref()
&& path.path.segments.len() > 1
&& path
.path
.segments
.last()
.is_some_and(|segment| self.provider_names.contains(&segment.ident))
{
self.error = Some(syn::Error::new_spanned(
expression,
"direct provider calls bypass dependency injection; inject the provided type as a parameter",
));
return;
}
syn::visit::visit_expr_call(self, expression);
}
}
fn expand_configuration(mut item_impl: ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
if item_impl.trait_.is_some() || !item_impl.generics.params.is_empty() {
return Err(syn::Error::new_spanned(
&item_impl,
"#[configuration] requires a non-generic inherent impl",
));
}
reject_direct_provider_calls(&item_impl)?;
let configuration_type = item_impl.self_ty.as_ref().clone();
let configuration_ident = singleton_type_ident(&configuration_type)?.clone();
let configuration_registration = registration(
&format_ident!("configuration_{configuration_ident}"),
&configuration_type,
&[],
&[],
&[],
quote! { <#configuration_type as ::std::default::Default>::default() },
);
let mut registrations = Vec::new();
for item in &mut item_impl.items {
let ImplItem::Fn(method) = item else { continue };
let Some(provider_attribute) = method
.attrs
.iter()
.find(|attribute| is_provider_attribute(attribute))
else {
continue;
};
let provider_options = match &provider_attribute.meta {
syn::Meta::Path(_) => ProviderOptions::default(),
syn::Meta::List(list) => parse_provider_options2(list.tokens.clone())?,
syn::Meta::NameValue(_) => {
return Err(syn::Error::new_spanned(
provider_attribute,
"use #[provider(...)]",
));
}
};
method
.attrs
.retain(|attribute| !is_provider_attribute(attribute));
let (output_type, fallible) = provided_output(&method.sig.output)?;
let (has_receiver, argument_names, dependency_types, qualifiers) =
provider_dependencies(&mut method.sig)?;
let method_name = method.sig.ident.clone();
let registration_name = format_ident!("provider_{configuration_ident}_{method_name}");
let invocation = if has_receiver {
quote! { configuration.#method_name(#(#argument_names),*) }
} else {
quote! { <#configuration_type>::#method_name(#(#argument_names),*) }
};
let is_async = method.sig.asyncness.is_some();
let invocation = if is_async {
quote! { #invocation.await }
} else {
invocation
};
let invocation = blocking_invocation(
invocation,
provider_options.blocking,
is_async,
&output_type,
)?;
let invocation = factory_invocation(invocation, fallible, &output_type);
let prelude = has_receiver.then(|| {
quote! {
let configuration: ::std::sync::Arc<#configuration_type> =
container.resolve_dependency::<#configuration_type>(&context).await?;
}
});
registrations.push(registration_with_prelude(
®istration_name,
&output_type,
&argument_names,
&dependency_types,
&qualifiers,
prelude.unwrap_or_default(),
invocation,
&provider_options,
));
}
if registrations.is_empty() {
return Err(syn::Error::new_spanned(
&item_impl,
"#[configuration] must contain at least one #[provider] method",
));
}
Ok(quote! {
#item_impl
#configuration_registration
#(#registrations)*
})
}
fn validate_impl_constructor(constructor: &ImplItemFn) -> syn::Result<bool> {
if constructor.sig.receiver().is_some() {
return Err(syn::Error::new_spanned(
&constructor.sig,
"the singleton new constructor must be an associated function",
));
}
match &constructor.sig.output {
ReturnType::Type(_, ty) if is_self_type(ty) => Ok(false),
ReturnType::Type(_, ty) if result_ok_type(ty).is_some_and(is_self_type) => Ok(true),
_ => Err(syn::Error::new_spanned(
&constructor.sig.output,
"the singleton new constructor must return Self or Result<Self, E>",
)),
}
}
fn provided_output(output: &ReturnType) -> syn::Result<(Type, bool)> {
let ReturnType::Type(_, ty) = output else {
return Err(syn::Error::new_spanned(
output,
"a provider must return a value",
));
};
if let Some(inner) = result_ok_type(ty) {
Ok((inner.clone(), true))
} else {
Ok((ty.as_ref().clone(), false))
}
}
/// Returns `T` when `ty` is a path type whose last segment is `Result<T, ..>`.
///
/// Because only the last segment is inspected, aliases such as `io::Result<T>` also match.
fn result_ok_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else { return None };
    let last = type_path
        .path
        .segments
        .last()
        .filter(|segment| segment.ident == "Result")?;
    match &last.arguments {
        PathArguments::AngleBracketed(generics) => match generics.args.first()? {
            GenericArgument::Type(ok_type) => Some(ok_type),
            _ => None,
        },
        _ => None,
    }
}
/// Returns `true` when `ty` is exactly the path `Self`.
fn is_self_type(ty: &Type) -> bool {
    if let Type::Path(type_path) = ty {
        type_path.path.is_ident("Self")
    } else {
        false
    }
}
/// Adapts a factory call for use inside the generated factory body.
///
/// An infallible call is returned unchanged. A fallible one is passed through
/// `__private::factory_result` together with the provided type's name, and the outcome is
/// propagated with `?`.
fn factory_invocation(
    invocation: proc_macro2::TokenStream,
    fallible: bool,
    output_type: &Type,
) -> proc_macro2::TokenStream {
    if !fallible {
        return invocation;
    }
    let auto_di = auto_di_path();
    quote! {
        #auto_di::__private::factory_result(
            #invocation,
            ::std::any::type_name::<#output_type>(),
        )?
    }
}
fn blocking_invocation(
invocation: proc_macro2::TokenStream,
blocking: bool,
is_async: bool,
output_type: &Type,
) -> syn::Result<proc_macro2::TokenStream> {
if !blocking {
return Ok(invocation);
}
if is_async {
return Err(syn::Error::new_spanned(
output_type,
"blocking can only be used with synchronous providers",
));
}
let auto_di = auto_di_path();
Ok(quote! {
#auto_di::__private::tokio::task::spawn_blocking(move || #invocation)
.await
.map_err(|error| #auto_di::DiError::Factory {
provider: ::std::any::type_name::<#output_type>(),
message: error.to_string(),
})?
})
}
/// Collects the injected dependencies of a constructor or provider function.
///
/// Returns parallel vectors: parameter names, parameter types and optional qualifiers. Each
/// parameter's `#[qualifier]` attribute is removed from the signature.
///
/// # Errors
/// Fails on:
/// - a `self` receiver;
/// - a non-identifier parameter pattern;
/// - a type `validate_dependency_type` rejects;
/// - a malformed qualifier.
fn dependencies(
    signature: &mut Signature,
) -> syn::Result<(Vec<syn::Ident>, Vec<Type>, Vec<Option<String>>)> {
    let mut names = Vec::new();
    let mut types = Vec::new();
    let mut qualifiers = Vec::new();
    for input in signature.inputs.iter_mut() {
        let FnArg::Typed(typed) = input else {
            return Err(syn::Error::new_spanned(
                input,
                "singleton constructors cannot have a self receiver",
            ));
        };
        let Pat::Ident(binding) = &*typed.pat else {
            return Err(syn::Error::new_spanned(
                &typed.pat,
                "use a simple parameter name for an injected dependency",
            ));
        };
        names.push(binding.ident.clone());
        let dependency_type: &Type = &typed.ty;
        validate_dependency_type(dependency_type)?;
        types.push(dependency_type.clone());
        qualifiers.push(take_qualifier(&mut typed.attrs)?);
    }
    Ok((names, types, qualifiers))
}
/// Like `dependencies`, but for `#[provider]` methods, which may take a single `&self`.
///
/// Returns `(has_receiver, names, types, qualifiers)`. Parameter `#[qualifier]` attributes
/// are removed from the signature.
///
/// # Errors
/// Fails on:
/// - a receiver other than one immutable `&self`;
/// - a non-identifier parameter pattern;
/// - an unsupported dependency type;
/// - a malformed qualifier.
fn provider_dependencies(
    signature: &mut Signature,
) -> syn::Result<(bool, Vec<syn::Ident>, Vec<Type>, Vec<Option<String>>)> {
    let mut takes_self = false;
    let (mut names, mut types, mut qualifiers) = (Vec::new(), Vec::new(), Vec::new());
    for input in signature.inputs.iter_mut() {
        match input {
            FnArg::Receiver(receiver) => {
                let is_shared_ref =
                    receiver.reference.is_some() && receiver.mutability.is_none();
                if takes_self || !is_shared_ref {
                    return Err(syn::Error::new_spanned(
                        receiver,
                        "a #[provider] method only supports an immutable &self receiver",
                    ));
                }
                takes_self = true;
            }
            FnArg::Typed(typed) => {
                let Pat::Ident(binding) = &*typed.pat else {
                    return Err(syn::Error::new_spanned(
                        &typed.pat,
                        "use a simple parameter name for an injected dependency",
                    ));
                };
                names.push(binding.ident.clone());
                let dependency_type: &Type = &typed.ty;
                validate_dependency_type(dependency_type)?;
                types.push(dependency_type.clone());
                qualifiers.push(take_qualifier(&mut typed.attrs)?);
            }
        }
    }
    Ok((takes_self, names, types, qualifiers))
}
/// Returns `true` for a `#[provider]` / `#[provider(...)]` attribute, matched by its bare name.
fn is_provider_attribute(attribute: &syn::Attribute) -> bool {
    let path = attribute.path();
    path.is_ident("provider")
}
/// Extracts the value of a `#[qualifier("...")]` attribute and removes all such attributes.
///
/// Returns `Ok(None)` when no qualifier is present.
///
/// # Errors
/// Fails if more than one qualifier is present or its argument is not a string literal.
fn take_qualifier(attributes: &mut Vec<syn::Attribute>) -> syn::Result<Option<String>> {
    fn is_qualifier(attribute: &syn::Attribute) -> bool {
        attribute.path().is_ident("qualifier")
    }
    let mut qualifier: Option<String> = None;
    for attribute in attributes.iter().filter(|attr| is_qualifier(attr)) {
        if qualifier.is_some() {
            return Err(syn::Error::new_spanned(
                attribute,
                "only one qualifier is allowed",
            ));
        }
        let literal: LitStr = attribute.parse_args()?;
        qualifier = Some(literal.value());
    }
    attributes.retain(|attr| !is_qualifier(attr));
    Ok(qualifier)
}
/// Registers a provider with default [`ProviderOptions`] and no prelude.
fn registration(
    name: &syn::Ident,
    output_type: &Type,
    argument_names: &[syn::Ident],
    dependency_types: &[Type],
    qualifiers: &[Option<String>],
    invocation: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let defaults = ProviderOptions::default();
    registration_with_options(
        name,
        output_type,
        argument_names,
        dependency_types,
        qualifiers,
        invocation,
        &defaults,
    )
}
/// Registers a provider with explicit options and an empty prelude.
fn registration_with_options(
    name: &syn::Ident,
    output_type: &Type,
    argument_names: &[syn::Ident],
    dependency_types: &[Type],
    qualifiers: &[Option<String>],
    invocation: proc_macro2::TokenStream,
    options: &ProviderOptions,
) -> proc_macro2::TokenStream {
    let empty_prelude = proc_macro2::TokenStream::new();
    registration_with_prelude(
        name,
        output_type,
        argument_names,
        dependency_types,
        qualifiers,
        empty_prelude,
        invocation,
        options,
    )
}
/// Emits everything needed to register one provider through `inventory`.
///
/// Generated items, all `#[doc(hidden)]` and named after `name`:
/// - `__di_type_id_{name}` / `__di_type_name_{name}`: the `TypeId` and type name of
///   `output_type`.
/// - `__di_factory_{name}`: an async factory. It runs `prelude`, resolves each dependency
///   (`argument_names[i]: dependency_types[i]`, qualified by `qualifiers[i]`), evaluates
///   `invocation`, awaits the optional `post_construct` hook, and returns the value as a
///   `DynArc`.
/// - `__di_destroy_{name}`, only with `pre_destroy`: downcasts the stored value and awaits the
///   hook method on it.
///
/// Finally, a `ProviderDescriptor::configured(...)` carrying the remaining `options` is
/// submitted.
///
/// The three dependency slices are zipped, so entries beyond the shortest one are ignored. An
/// unrecognised `scope` string falls back to `Scope::Singleton`.
///
/// NOTE(review): `format_ident!` panics if a hook name is not a valid identifier. The
/// `post_construct` hook is called on a non-`mut` local and awaited, so it presumably must take
/// `&self` and return a future.
fn registration_with_prelude(
    name: &syn::Ident,
    output_type: &Type,
    argument_names: &[syn::Ident],
    dependency_types: &[Type],
    qualifiers: &[Option<String>],
    prelude: proc_macro2::TokenStream,
    invocation: proc_macro2::TokenStream,
    options: &ProviderOptions,
) -> proc_macro2::TokenStream {
    let auto_di = auto_di_path();
    let factory_name = format_ident!("__di_factory_{name}");
    let type_id_name = format_ident!("__di_type_id_{name}");
    let type_name_name = format_ident!("__di_type_name_{name}");
    let destroy_name = format_ident!("__di_destroy_{name}");
    // Renders an optional string as a `Some("...")` / `None` expression.
    let option_tokens = |value: Option<&str>| match value {
        Some(value) => quote!(Some(#value)),
        None => quote!(None),
    };
    let provider_name = option_tokens(options.name.as_deref());
    let primary = options.primary;
    let eager = options.eager;
    let profile = option_tokens(options.profile.as_deref());
    let condition_key = option_tokens(options.condition_key.as_deref());
    let condition_value = option_tokens(options.condition_value.as_deref());
    let scope = match options.scope.as_deref().unwrap_or("singleton") {
        "prototype" => quote!(#auto_di::Scope::Prototype),
        "request" => quote!(#auto_di::Scope::Request),
        _ => quote!(#auto_di::Scope::Singleton),
    };
    // Awaited on the new value inside the factory, before it is wrapped in an `Arc`.
    let post_construct = options.post_construct.as_ref().map(|method| {
        let method = format_ident!("{method}");
        quote! {
            #auto_di::__private::lifecycle_result(
                value.#method().await,
                ::std::any::type_name::<#output_type>(),
            )?;
        }
    });
    // With `pre_destroy`, emit a destroy function and pass it to the descriptor; without it,
    // pass `None`.
    let (destroy_function, destroy_value) = if let Some(method) = &options.pre_destroy {
        let method = format_ident!("{method}");
        (
            quote! {
                #[doc(hidden)]
                fn #destroy_name(value: #auto_di::DynArc) -> #auto_di::BoxFuture<'static, ::std::result::Result<(), #auto_di::DiError>> {
                    ::std::boxed::Box::pin(async move {
                        let value = value.downcast::<#output_type>()
                            .map_err(|_| #auto_di::DiError::TypeMismatch(::std::any::type_name::<#output_type>()))?;
                        #auto_di::__private::lifecycle_result(
                            value.#method().await,
                            ::std::any::type_name::<#output_type>(),
                        )?;
                        Ok(())
                    })
                }
            },
            quote!(Some(#destroy_name as _)),
        )
    } else {
        (quote! {}, quote!(None))
    };
    let resolve_dependencies = argument_names
        .iter()
        .zip(dependency_types.iter())
        .zip(qualifiers.iter())
        .map(|((name, ty), qualifier)| dependency_resolution(name, ty, qualifier.as_deref()));
    quote! {
        #[doc(hidden)]
        fn #type_id_name() -> ::std::any::TypeId {
            ::std::any::TypeId::of::<#output_type>()
        }
        #[doc(hidden)]
        fn #type_name_name() -> &'static str {
            ::std::any::type_name::<#output_type>()
        }
        #[doc(hidden)]
        fn #factory_name<'a>(
            container: &'a #auto_di::Container,
            context: #auto_di::ResolutionContext,
        ) -> #auto_di::BoxFuture<'a, ::std::result::Result<#auto_di::DynArc, #auto_di::DiError>> {
            ::std::boxed::Box::pin(async move {
                #prelude
                #(#resolve_dependencies)*
                let value: #output_type = #invocation;
                #post_construct
                Ok(::std::sync::Arc::new(value) as #auto_di::DynArc)
            })
        }
        #destroy_function
        #auto_di::__private::inventory::submit! {
            #auto_di::ProviderDescriptor::configured(
                #type_id_name,
                #type_name_name,
                #factory_name,
                #provider_name,
                #primary,
                #scope,
                #eager,
                #profile,
                #condition_key,
                #condition_value,
                #destroy_value,
            )
        }
    }
}
/// Returns the identifier that names the self type of a `#[singleton]` impl block.
///
/// Only path types are accepted (`Foo`, `crate::Foo`, `Foo<T>`, ...); the identifier of the
/// last path segment is returned. Invisible `None`-delimited groups are looked through first:
/// syn wraps a type in `Type::Group` when it was substituted from a `macro_rules!` `$t:ty`
/// fragment, so `impl $t { .. }` generated by a declarative macro is still accepted.
///
/// # Errors
/// Returns an error spanned on `ty` when it is not a path type or has no path segments.
fn singleton_type_ident(ty: &Type) -> syn::Result<&syn::Ident> {
    let mut inner = ty;
    while let Type::Group(group) = inner {
        inner = &group.elem;
    }
    let Type::Path(path) = inner else {
        return Err(syn::Error::new_spanned(
            ty,
            "singleton impl type must be a named type",
        ));
    };
    path.path
        .segments
        .last()
        .map(|segment| &segment.ident)
        .ok_or_else(|| syn::Error::new_spanned(ty, "singleton impl type must be a named type"))
}
/// Checks that a constructor parameter has one of the injectable dependency shapes.
///
/// Accepted shapes are `Arc<T>`, `Provider<T>` and `Lazy<T>` directly, plus `Option<Arc<T>>`
/// and `Vec<Arc<T>>`. Wrapper names are compared against the last path segment via
/// `generic_inner`, so fully-qualified spellings such as `std::sync::Arc<T>` also match.
///
/// # Errors
/// Returns an error spanned on `ty`, listing the supported forms, for any other type.
fn validate_dependency_type(ty: &Type) -> syn::Result<()> {
    let is_direct_handle = ["Arc", "Provider", "Lazy"]
        .into_iter()
        .any(|wrapper| generic_inner(ty, wrapper).is_some());

    // `Option` / `Vec` are only injectable when they wrap an `Arc`.
    let is_wrapped_arc = ["Option", "Vec"]
        .into_iter()
        .find_map(|wrapper| generic_inner(ty, wrapper))
        .and_then(|element| generic_inner(element, "Arc"))
        .is_some();

    if is_direct_handle || is_wrapped_arc {
        Ok(())
    } else {
        Err(syn::Error::new_spanned(
            ty,
            "dependency must be Arc<T>, Option<Arc<T>>, Vec<Arc<T>>, Provider<T>, or Lazy<T>",
        ))
    }
}
/// Returns `T` when `ty` is the single-argument generic `expected<T>`.
///
/// Only the last path segment is compared, so `Arc<T>`, `std::sync::Arc<T>` and
/// `::std::sync::Arc<T>` all match `expected = "Arc"`. Invisible `None`-delimited groups around
/// `ty` are looked through before matching: syn produces `Type::Group` when a type was
/// substituted from a `macro_rules!` `$t:ty` fragment, and such types would otherwise never
/// match. The returned inner type is left exactly as written, including any group, so callers
/// can re-emit it without changing its token grouping.
///
/// Returns `None` for non-path types, a different last segment, missing or parenthesized
/// arguments, or anything other than exactly one type argument.
fn generic_inner<'a>(ty: &'a Type, expected: &str) -> Option<&'a Type> {
    let mut ty = ty;
    while let Type::Group(group) = ty {
        ty = &group.elem;
    }
    let Type::Path(path) = ty else { return None };
    let segment = path.path.segments.last()?;
    if segment.ident != expected {
        return None;
    }
    let PathArguments::AngleBracketed(arguments) = &segment.arguments else {
        return None;
    };
    match arguments.args.first() {
        Some(GenericArgument::Type(inner)) if arguments.args.len() == 1 => Some(inner),
        _ => None,
    }
}
/// Generates the `let` statement that binds one constructor dependency inside a factory body.
///
/// The emitted statement refers to the factory's `container` and `context` bindings and uses
/// `.await?`, so it must be spliced into the generated async factory. `ty` is expected to have
/// passed `validate_dependency_type`: the `Option`/`Vec` branches `expect` an inner `Arc`, and
/// the final branch assumes `Lazy<T>`.
///
/// Shapes handled:
/// - `Arc<T>` resolves via `resolve_dependency::<T>`, or `resolve_named_dependency::<T>` when
///   `qualifier` is set.
/// - `Arc<dyn Trait>` asks the container for `Arc<dyn Trait>` itself (honouring `qualifier`),
///   dereferences the result and clones the inner `Arc`.
/// - `Option<Arc<..>>` uses `resolve_optional_dependency`.
/// - `Vec<Arc<..>>` uses `resolve_all_dependency`.
/// - `Provider<T>` / `Lazy<T>` are built from clones of `container` and `context`.
///
/// NOTE(review): `qualifier` is only applied to the `Arc` shapes; for `Option`, `Vec`,
/// `Provider` and `Lazy` it is silently ignored — confirm callers reject it upstream.
fn dependency_resolution(
    name: &syn::Ident,
    ty: &Type,
    qualifier: Option<&str>,
) -> proc_macro2::TokenStream {
    /// True when `ty` is a trait object. Looks through invisible groups (which wrap types
    /// substituted from `macro_rules!` `$t:ty` fragments) and redundant parentheses, so that
    /// `Arc<$t>` with `$t = dyn Trait`, or `Arc<(dyn Trait)>`, takes the trait-object path.
    fn is_trait_object(ty: &Type) -> bool {
        match ty {
            Type::TraitObject(_) => true,
            Type::Group(group) => is_trait_object(&group.elem),
            Type::Paren(paren) => is_trait_object(&paren.elem),
            _ => false,
        }
    }

    if let Some(inner) = generic_inner(ty, "Arc") {
        if is_trait_object(inner) {
            // The trait object is not requested directly; the `Arc<dyn Trait>` handle is the
            // resolved type (presumably because an unsized type can't be the resolved value),
            // and the inner `Arc` is cloned out of whatever smart pointer the container returns.
            return match qualifier {
                Some(qualifier) => quote! {
                    let #name: #ty = (*container.resolve_named_dependency::<#ty>(#qualifier, &context).await?).clone();
                },
                None => quote! {
                    let #name: #ty = (*container.resolve_dependency::<#ty>(&context).await?).clone();
                },
            };
        }
        return match qualifier {
            Some(qualifier) => quote! {
                let #name: #ty = container.resolve_named_dependency::<#inner>(#qualifier, &context).await?;
            },
            None => quote! {
                let #name: #ty = container.resolve_dependency::<#inner>(&context).await?;
            },
        };
    }

    if let Some(wrapped) = generic_inner(ty, "Option") {
        let inner = generic_inner(wrapped, "Arc").expect("validated Option<Arc<T>>");
        if is_trait_object(inner) {
            return quote! {
                let #name: #ty = container
                    .resolve_optional_dependency::<#wrapped>(&context)
                    .await?
                    .map(|value| (*value).clone());
            };
        }
        return quote! {
            let #name: #ty = container.resolve_optional_dependency::<#inner>(&context).await?;
        };
    }

    if let Some(wrapped) = generic_inner(ty, "Vec") {
        let inner = generic_inner(wrapped, "Arc").expect("validated Vec<Arc<T>>");
        if is_trait_object(inner) {
            return quote! {
                let #name: #ty = container
                    .resolve_all_dependency::<#wrapped>(&context)
                    .await?
                    .into_iter()
                    .map(|value| (*value).clone())
                    .collect();
            };
        }
        return quote! {
            let #name: #ty = container.resolve_all_dependency::<#inner>(&context).await?;
        };
    }

    // Only the `Provider` / `Lazy` handles name the runtime crate, so the (manifest-reading)
    // crate-path lookup is done here instead of once per dependency.
    let auto_di = auto_di_path();
    if generic_inner(ty, "Provider").is_some() {
        return quote! {
            let #name: #ty = #auto_di::Provider::from_context(container.clone(), context.clone());
        };
    }
    // Assuming `ty` passed `validate_dependency_type`, `Lazy<T>` is the only shape left.
    quote! {
        let #name: #ty = #auto_di::Lazy::from_context(container.clone(), context.clone());
    }
}