use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, Ident, ItemTrait, Path, Token};
/// Parsed arguments of the service attribute macro.
///
/// Either field may be absent; the `Parse` impl for this type defines the accepted syntax.
struct ServiceArgs {
    /// Concrete type that gets default-constructed and bound as `dyn Trait`
    /// by the generated registration entry.
    impl_type: Option<Path>,
    /// Type that gets default-constructed and bound as `dyn Trait` by the
    /// generated `fake()` test helper.
    fake_type: Option<Path>,
}
impl Parse for ServiceArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut impl_type = None;
let mut fake_type = None;
if input.is_empty() {
return Ok(ServiceArgs {
impl_type: None,
fake_type: None,
});
}
let fork = input.fork();
let is_named = if fork.parse::<Ident>().is_ok() {
fork.peek(Token![=])
} else {
false
};
if is_named {
while !input.is_empty() {
let name: Ident = input.parse()?;
input.parse::<Token![=]>()?;
let path: Path = input.parse()?;
match name.to_string().as_str() {
"impl" => impl_type = Some(path),
"fake" => fake_type = Some(path),
_ => {
return Err(syn::Error::new(
name.span(),
format!("unknown parameter '{name}', expected 'impl' or 'fake'"),
))
}
}
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
}
} else {
impl_type = Some(input.parse()?);
}
Ok(ServiceArgs {
impl_type,
fake_type,
})
}
}
pub fn service_impl(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as ServiceArgs);
let mut item_trait = parse_macro_input!(input as ItemTrait);
let ferro = quote!(::ferro);
let send_bound: syn::TypeParamBound = syn::parse_quote!(Send);
let sync_bound: syn::TypeParamBound = syn::parse_quote!(Sync);
let static_bound: syn::TypeParamBound = syn::parse_quote!('static);
let has_send = item_trait.supertraits.iter().any(|bound| {
if let syn::TypeParamBound::Trait(trait_bound) = bound {
trait_bound
.path
.segments
.last()
.map(|s| s.ident == "Send")
.unwrap_or(false)
} else {
false
}
});
let has_sync = item_trait.supertraits.iter().any(|bound| {
if let syn::TypeParamBound::Trait(trait_bound) = bound {
trait_bound
.path
.segments
.last()
.map(|s| s.ident == "Sync")
.unwrap_or(false)
} else {
false
}
});
let has_static = item_trait
.supertraits
.iter()
.any(|bound| matches!(bound, syn::TypeParamBound::Lifetime(lt) if lt.ident == "static"));
if !has_send {
item_trait.supertraits.push(send_bound);
}
if !has_sync {
item_trait.supertraits.push(sync_bound);
}
if !has_static {
item_trait.supertraits.push(static_bound);
}
let trait_name = &item_trait.ident;
let trait_name_str = trait_name.to_string();
let impl_registration = args.impl_type.as_ref().map(|concrete_type| {
quote! {
#ferro::inventory::submit! {
#ferro::container::provider::ServiceBindingEntry {
register: || {
#ferro::App::bind::<dyn #trait_name>(
::std::sync::Arc::new(<#concrete_type as ::std::default::Default>::default())
);
},
name: #trait_name_str,
}
}
}
});
let fake_impl = args.fake_type.as_ref().map(|fake_type| {
quote! {
impl dyn #trait_name {
pub fn fake() -> #ferro::container::testing::TestContainerGuard {
let guard = #ferro::container::testing::TestContainer::fake();
#ferro::container::testing::TestContainer::bind::<dyn #trait_name>(
::std::sync::Arc::new(<#fake_type as ::std::default::Default>::default())
);
guard
}
}
}
});
let expanded = quote! {
#item_trait
#impl_registration
#fake_impl
};
TokenStream::from(expanded)
}