use proc_macro2::TokenStream;
use quote::quote;
use syn::{ItemImpl, ItemTrait, Path, Token, parse::Parse, parse::ParseStream};
pub mod factory;
pub mod naming;
pub mod params;
pub mod type_map;
pub mod vtable;
use naming::factory_struct_name;
use params::MethodInfo;
use type_map::TypeMap;
/// Parsed arguments of the `#[reflect_trait(...)]` attribute.
///
/// Accepted forms (each may carry a single trailing comma):
/// - `#[reflect_trait(path::to::Trait)]`
/// - `#[reflect_trait(path::to::Trait, type_map(...))]`
struct ReflectTraitAttr {
    /// Path of the trait being reflected.
    trait_path: Path,
    /// Type substitutions from `type_map(...)`; `TypeMap::default()` when omitted.
    type_map: TypeMap,
}

impl Parse for ReflectTraitAttr {
    /// Parses `Path [, type_map(...)] [,]`.
    ///
    /// # Errors
    ///
    /// Fails if the trait path is malformed, if the identifier after the comma
    /// is anything other than `type_map`, or if `TypeMap` parsing fails.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let trait_path: Path = input.parse()?;
        let mut type_map = TypeMap::default();
        if input.peek(Token![,]) {
            let _: Token![,] = input.parse()?;
            // A comma with nothing after it is a trailing comma, not the start
            // of a `type_map(...)` clause; accept it instead of failing on the
            // missing identifier.
            if !input.is_empty() {
                let kw: syn::Ident = input.parse()?;
                if kw != "type_map" {
                    return Err(syn::Error::new_spanned(
                        kw,
                        "#[reflect_trait]: expected `type_map(...)` after trait path",
                    ));
                }
                type_map = input.parse::<TypeMap>()?;
                // Tolerate a trailing comma after `type_map(...)` as well.
                if input.peek(Token![,]) {
                    let _: Token![,] = input.parse()?;
                }
            }
        }
        Ok(ReflectTraitAttr {
            trait_path,
            type_map,
        })
    }
}
pub fn expand(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
let ReflectTraitAttr {
trait_path,
type_map,
} = syn::parse2::<ReflectTraitAttr>(attr)?;
let trait_path_str = path_to_string(&trait_path);
let methods: Vec<MethodInfo> = if let Ok(trait_item) = syn::parse2::<ItemTrait>(item.clone()) {
let vis = &trait_item.vis;
let _ = vis;
MethodInfo::from_trait_items(&trait_item.items)?
} else {
let impl_block = syn::parse2::<ItemImpl>(item)?;
let vis = &impl_block.self_ty;
let _ = vis;
MethodInfo::from_impl_items(&impl_block.items)?
};
if methods.is_empty() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"#[reflect_trait]: trait/impl block must contain at least one method signature",
));
}
let gen_vis: syn::Visibility = syn::parse_quote!(pub);
let factory_description = format!("Tools for types implementing `{trait_path_str}`");
let param_structs: Vec<TokenStream> = methods
.iter()
.map(|m| m.param_struct_tokens(&gen_vis, &type_map))
.collect();
let vtable_ts =
vtable::vtable_tokens(&trait_path, &trait_path_str, &methods, &gen_vis, &type_map);
let factory_ts = factory::factory_tokens(
&trait_path,
&trait_path_str,
&factory_description,
&methods,
&gen_vis,
);
let factory_name = factory_struct_name(&trait_path_str);
let prime_fn_name = proc_macro2::Ident::new(
&format!(
"prime_{}",
naming::to_snake_path(&trait_path_str).replace("::", "__")
),
proc_macro2::Span::call_site(),
);
Ok(quote! {
#(#param_structs)*
#vtable_ts
#factory_ts
pub fn #prime_fn_name<T>()
where
T: #trait_path
+ ::serde::Serialize
+ ::serde::de::DeserializeOwned
+ ::schemars::JsonSchema
+ ::elicitation::Elicitation
+ Send + Sync + 'static,
{
#factory_name::prime::<T>();
}
})
}
/// Renders `path` as `a::b::C` using only each segment's identifier.
///
/// Generic arguments and a leading `::` are omitted. For example, both
/// `::foo::Bar<T>` and `foo::Bar` yield `"foo::Bar"`.
fn path_to_string(path: &Path) -> String {
    let mut rendered = String::new();
    for (position, segment) in path.segments.iter().enumerate() {
        if position > 0 {
            rendered.push_str("::");
        }
        rendered.push_str(&segment.ident.to_string());
    }
    rendered
}