use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, LitStr, Type};
use crate::common::krate;
/// Heuristically decides whether a field type is the framework's `Logger`.
///
/// Only path types are considered. The check compares the identifier of the
/// final path segment against `Logger`, so both `Logger` and
/// `crate::logging::Logger` match, while e.g. `Option<Logger>` does not.
fn is_logger(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    match type_path.path.segments.last() {
        Some(segment) => segment.ident == "Logger",
        None => false,
    }
}
/// Reads the `context = "..."` value from the `#[name(...)]` attributes in `attrs`.
///
/// Every attribute whose path is `name` is inspected, so the key may live in
/// any one of them. Returns `Ok(None)` when no such attribute sets `context`.
///
/// # Errors
///
/// Fails when a `#[name(...)]` attribute contains a key other than `context`,
/// when the value is not a string literal, or when `context` is specified more
/// than once (within one attribute or across several).
fn context_attr(attrs: &[Attribute], name: &str) -> syn::Result<Option<String>> {
    let mut context: Option<String> = None;
    for attr in attrs.iter().filter(|attr| attr.path().is_ident(name)) {
        attr.parse_nested_meta(|meta| {
            if !meta.path.is_ident("context") {
                return Err(meta.error("expected `context = \"...\"`"));
            }
            // Reject repeats instead of silently letting the last one win.
            if context.is_some() {
                return Err(meta.error("duplicate `context` key"));
            }
            let value: LitStr = meta.value()?.parse()?;
            context = Some(value.value());
            Ok(())
        })?;
    }
    Ok(context)
}
pub fn expand(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
match expand_derive(input) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
/// Builds the `FromRequest` impl for a `#[derive(Inject)]` struct.
///
/// The generated `from_request` first asks `__take_override::<Self>` for a
/// pre-registered value and returns it if present. Otherwise it resolves every
/// named field through that field type's own `FromRequest` impl, in declaration
/// order, and assembles the struct.
///
/// Fields detected as a `Logger` (see `is_logger`) are also passed through
/// `.for_context(..)`. The context is chosen by precedence: the field's
/// `#[logger(context = "...")]`, then the struct's `#[inject(context = "...")]`,
/// then the struct's own name.
///
/// # Errors
///
/// Returns an error spanned to the offending input when:
/// - the item is not a struct with named fields;
/// - an `inject`/`logger` attribute is malformed;
/// - `#[logger(...)]` is placed on a field that is not a `Logger`.
fn expand_derive(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let fields = match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(named) => &named.named,
            _ => {
                return Err(syn::Error::new_spanned(
                    &input,
                    "#[derive(Inject)] requires a struct with named fields",
                ));
            }
        },
        _ => {
            return Err(syn::Error::new_spanned(
                &input,
                "#[derive(Inject)] can only be derived for structs",
            ));
        }
    };
    let krate = krate();
    let ident = &input.ident;
    // Carry the struct's generics and where-clause onto the impl so generic
    // structs can derive too; for non-generic structs all three are empty.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let container_context = context_attr(&input.attrs, "inject")?;
    // The request-context parameter gets a mixed-site (hygienic) span. A user
    // field named `ctx` is bound below with `let #field_ident = ...`; with a
    // plain `ctx` that binding would shadow the parameter for every later field.
    let ctx = syn::Ident::new("ctx", proc_macro2::Span::mixed_site());
    let mut bindings = Vec::with_capacity(fields.len());
    let mut names = Vec::with_capacity(fields.len());
    for field in fields {
        let field_ident = field.ident.as_ref().expect("named field");
        let field_ty = &field.ty;
        let binding = if is_logger(field_ty) {
            let context = context_attr(&field.attrs, "logger")?
                .or_else(|| container_context.clone())
                .unwrap_or_else(|| ident.to_string());
            quote! {
                let #field_ident = <#field_ty as #krate::FromRequest>::from_request(#ctx)
                    .await?
                    .for_context(#context);
            }
        } else {
            // `#[logger(...)]` only has an effect on `Logger` fields; report it
            // rather than silently dropping the user's configuration.
            if let Some(attr) = field.attrs.iter().find(|attr| attr.path().is_ident("logger")) {
                return Err(syn::Error::new_spanned(
                    attr,
                    "`#[logger(...)]` can only be applied to `Logger` fields",
                ));
            }
            quote! {
                let #field_ident = <#field_ty as #krate::FromRequest>::from_request(#ctx).await?;
            }
        };
        bindings.push(binding);
        names.push(field_ident);
    }
    Ok(quote! {
        impl #impl_generics #krate::FromRequest for #ident #ty_generics #where_clause {
            fn from_request(
                #ctx: & #krate::RequestContext,
            ) -> impl ::core::future::Future<Output = #krate::Result<Self>> + ::core::marker::Send {
                async move {
                    if let ::core::option::Option::Some(__overridden) =
                        #krate::__take_override::<Self>(#ctx)
                    {
                        return ::core::result::Result::Ok(__overridden);
                    }
                    #(#bindings)*
                    ::core::result::Result::Ok(#ident { #(#names),* })
                }
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn is_logger_detects_final_segment() {
        let logger: Type = parse_quote!(Logger);
        let nested: Type = parse_quote!(crate::logging::Logger);
        let other: Type = parse_quote!(Db);
        assert!(is_logger(&logger));
        assert!(is_logger(&nested));
        assert!(!is_logger(&other));
    }

    #[test]
    fn context_attr_reads_known_value_and_rejects_unknown_keys() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[inject(context = "api")])];
        assert_eq!(
            context_attr(&attrs, "inject").unwrap(),
            Some("api".to_owned())
        );
        assert_eq!(context_attr(&[], "inject").unwrap(), None);
        let attrs: Vec<Attribute> = vec![parse_quote!(#[inject(foo = "bar")])];
        let err = context_attr(&attrs, "inject").unwrap_err();
        assert!(err.to_string().contains("expected `context = \"...\"`"));
    }

    #[test]
    fn expand_derive_rejects_invalid_struct_shapes() {
        let input: DeriveInput = parse_quote!(
            enum NotInjectable {
                A,
            }
        );
        assert!(expand_derive(input)
            .unwrap_err()
            .to_string()
            .contains("only be derived for structs"));
        let input: DeriveInput = parse_quote! {
            struct Tuple(Logger);
        };
        assert!(expand_derive(input)
            .unwrap_err()
            .to_string()
            .contains("named fields"));
    }

    #[test]
    fn expand_derive_uses_logger_context_precedence_and_override() {
        let input: DeriveInput = parse_quote! {
            #[inject(context = "container")]
            struct Service {
                db: Db,
                #[logger(context = "field")]
                logger: Logger,
            }
        };
        let tokens = expand_derive(input).unwrap().to_string();
        assert!(tokens.contains("__take_override"));
        assert!(tokens.contains("FromRequest for Service"));
        assert!(tokens.contains("let db ="));
        assert!(tokens.contains("let logger ="));
        assert!(tokens.contains("for_context"));
        // Match the quoted literal so the crate path cannot produce a false hit.
        assert!(tokens.contains("\"field\""));
        // The field-level context wins, so the container context must not appear.
        assert!(!tokens.contains("\"container\""));
    }

    #[test]
    fn expand_derive_falls_back_to_container_then_struct_name() {
        let input: DeriveInput = parse_quote! {
            #[inject(context = "container")]
            struct Service {
                logger: Logger,
            }
        };
        let tokens = expand_derive(input).unwrap().to_string();
        assert!(tokens.contains("\"container\""));

        let input: DeriveInput = parse_quote! {
            struct Service {
                logger: Logger,
            }
        };
        let tokens = expand_derive(input).unwrap().to_string();
        assert!(tokens.contains("\"Service\""));
    }

    #[test]
    fn expand_derive_does_not_contextualize_plain_fields() {
        let input: DeriveInput = parse_quote! {
            struct Service {
                db: Db,
            }
        };
        let tokens = expand_derive(input).unwrap().to_string();
        assert!(tokens.contains("let db ="));
        assert!(!tokens.contains("for_context"));
    }
}