use crate::crate_paths::{
get_reinhardt_core_crate, get_reinhardt_di_crate, get_reinhardt_http_crate,
};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, FnArg, ItemFn, Pat, PatType, Result, Type};
/// Reports whether `attr` is the parameter-level `#[inject]` marker, either bare
/// (`#[inject]`) or with arguments (e.g. `#[inject(cache = false)]`).
fn is_inject_attr(attr: &Attribute) -> bool {
    const MARKER: &str = "inject";
    attr.path().is_ident(MARKER)
}
#[derive(Clone)]
struct ProcessedArg {
pat: Pat,
ty: Type,
inject: bool,
use_cache: bool,
}
impl ProcessedArg {
fn from_fn_arg(arg: &FnArg) -> Option<Self> {
match arg {
FnArg::Typed(PatType { attrs, pat, ty, .. }) => {
let mut inject = false;
let mut use_cache = true;
for attr in attrs {
if is_inject_attr(attr) {
inject = true;
if let Ok(meta) = attr.parse_args::<syn::Meta>()
&& let syn::Meta::NameValue(nv) = meta
&& nv.path.is_ident("cache")
&& let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Bool(lit_bool),
..
}) = &nv.value
{
use_cache = lit_bool.value;
}
}
}
let pat_without_mut = match &**pat {
syn::Pat::Ident(pat_ident) => {
let mut new_pat_ident = pat_ident.clone();
new_pat_ident.mutability = None;
syn::Pat::Ident(new_pat_ident)
}
other => other.clone(),
};
Some(ProcessedArg {
pat: pat_without_mut,
ty: (**ty).clone(),
inject,
use_cache,
})
}
_ => None,
}
}
}
pub(crate) fn use_inject_impl(_args: TokenStream, input: ItemFn) -> Result<TokenStream> {
let di_crate = get_reinhardt_di_crate();
let core_crate = get_reinhardt_core_crate();
let http_crate = get_reinhardt_http_crate();
let ItemFn {
attrs,
vis,
sig,
block,
} = input;
let fn_name = &sig.ident;
let asyncness = &sig.asyncness;
let output = &sig.output;
let generics = &sig.generics;
let where_clause = &sig.generics.where_clause;
if asyncness.is_none() {
return Err(syn::Error::new_spanned(
&sig,
"#[use_inject] can only be used on async functions",
));
}
if matches!(sig.output, syn::ReturnType::Default) {
return Err(syn::Error::new_spanned(
&sig,
"#[use_inject] functions must have an explicit return type",
));
}
let mut processed_args = Vec::new();
let mut self_param: Option<FnArg> = None;
let mut inject_params = Vec::new();
let mut request_param = None;
let mut other_params = Vec::new();
for arg in &sig.inputs {
if let Some(processed) = ProcessedArg::from_fn_arg(arg) {
if processed.inject {
inject_params.push(processed);
} else {
let is_request = if let Type::Path(type_path) = &processed.ty {
type_path
.path
.segments
.last()
.is_some_and(|s| s.ident == "Request")
} else {
false
};
if is_request {
request_param = Some(processed.pat.clone());
} else {
other_params.push(processed.clone());
}
processed_args.push(processed);
}
} else {
self_param = Some(arg.clone());
}
}
let has_request = request_param.is_some();
let request_pat: Pat = request_param.unwrap_or_else(|| {
syn::parse_quote! { __req }
});
let original_fn_name = syn::Ident::new(&format!("{}_original", fn_name), fn_name.span());
let mut original_params: Vec<FnArg> = Vec::new();
if let Some(ref self_p) = self_param {
original_params.push(self_p.clone());
}
for arg in &processed_args {
let pat = &arg.pat;
let ty = &arg.ty;
original_params.push(syn::parse_quote! { #pat: #ty });
}
for arg in &inject_params {
let pat = &arg.pat;
let ty = &arg.ty;
original_params.push(syn::parse_quote! { #pat: #ty });
}
let di_context_extraction = if !inject_params.is_empty() {
quote! {
let __shared_ctx = match #request_pat.get_di_context::<::std::sync::Arc<#di_crate::InjectionContext>>() {
Some(ctx) => ctx,
None => {
::tracing::warn!(
"DI context not set on router. Creating empty fallback context. \
Hint: Configure the router with .with_di_context() for proper dependency injection."
);
::std::sync::Arc::new(::std::sync::Arc::new(
#di_crate::InjectionContext::builder(#di_crate::SingletonScope::new()).build()
))
}
};
let __di_request = #request_pat.clone_for_di();
let __di_ctx = ::std::sync::Arc::new((*__shared_ctx).fork_for_request(__di_request));
let __resolve_ctx = #di_crate::resolve_context::ResolveContext {
root: ::std::sync::Arc::clone(&__shared_ctx),
current: ::std::sync::Arc::clone(&__di_ctx),
};
}
} else {
quote! {}
};
let mut injection_stmts = Vec::new();
for arg in &inject_params {
let pat = &arg.pat;
let ty = &arg.ty;
let injection_code = if arg.use_cache {
quote! {
let #pat: #ty = #di_crate::Depends::<#ty>::resolve(&__di_ctx, true)
.await
.map_err(#core_crate::exception::Error::from)?
.into_inner();
}
} else {
quote! {
let #pat: #ty = #di_crate::Depends::<#ty>::resolve(&__di_ctx, false)
.await
.map_err(#core_crate::exception::Error::from)?
.into_inner();
}
};
injection_stmts.push(injection_code);
}
let call_args: Vec<_> = processed_args
.iter()
.chain(inject_params.iter())
.map(|arg| &arg.pat)
.collect();
let is_method = self_param.is_some();
let original_call = if is_method {
quote! { self.#original_fn_name(#(#call_args),*).await }
} else {
quote! { #original_fn_name(#(#call_args),*).await }
};
let other_param_tokens: Vec<_> = other_params
.iter()
.map(|arg| {
let pat = &arg.pat;
let ty = &arg.ty;
quote! { #pat: #ty }
})
.collect();
let wrapper_params = if is_method {
let self_p = self_param.as_ref().unwrap();
if has_request {
if other_param_tokens.is_empty() {
quote! { #self_p, #request_pat: #http_crate::Request }
} else {
quote! { #self_p, #request_pat: #http_crate::Request, #(#other_param_tokens),* }
}
} else {
if other_param_tokens.is_empty() {
quote! { #self_p, #request_pat: #http_crate::Request }
} else {
quote! { #self_p, #request_pat: #http_crate::Request, #(#other_param_tokens),* }
}
}
} else if has_request {
if other_param_tokens.is_empty() {
quote! { #request_pat: #http_crate::Request }
} else {
quote! { #request_pat: #http_crate::Request, #(#other_param_tokens),* }
}
} else {
if other_param_tokens.is_empty() {
quote! { #request_pat: #http_crate::Request }
} else {
quote! { #request_pat: #http_crate::Request, #(#other_param_tokens),* }
}
};
let handler_body = quote! {
#(#injection_stmts)*
#original_call
};
let scoped_handler_body = if !inject_params.is_empty() {
quote! {
#di_crate::resolve_context::RESOLVE_CTX.scope(__resolve_ctx, async {
#handler_body
}).await
}
} else {
handler_body
};
let expanded = quote! {
#asyncness fn #original_fn_name #generics (#(#original_params),*) #output #where_clause {
#block
}
#(#attrs)*
#vis #asyncness fn #fn_name #generics (#wrapper_params) #output #where_clause {
#di_context_extraction
#scoped_handler_body
}
};
Ok(expanded)
}