use std::iter;
use proc_macro2::Span;
use quote::quote;
/// Derives `eraserhead::Erase` for the annotated type.
///
/// The expansion consists of two items:
///
/// * a companion struct named `Erased<Name>` with the same visibility as the
///   input type. It declares only the input's type parameters (bounds and
///   defaults stripped, lifetimes dropped) and holds a single `PhantomData`
///   field mentioning all of them;
/// * an `unsafe impl eraserhead::Erase` for the input type whose `Erased`
///   associated type is the companion struct with every type parameter `T`
///   replaced by `T::Erased`.
///
/// Every type parameter of the input additionally receives an
/// `eraserhead::Erase` bound, and a `T::Erased: Sized` predicate is appended
/// to the where clause of the generated impl.
///
/// Input that cannot be parsed as a `syn::DeriveInput` is reported as a
/// compile error rather than a macro panic.
///
/// # Panics
///
/// Panics (through `todo!`) when the input declares const generic parameters.
#[proc_macro_derive(Erase)]
pub fn derive_erase(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let mut ast: syn::DeriveInput = match syn::parse(input) {
        Ok(ast) => ast,
        // Surface the parse failure as a regular, spanned compiler diagnostic.
        Err(err) => return err.to_compile_error().into(),
    };

    // The bound must be injected before the generics are split, so that it
    // shows up in `impl_generics`.
    inject_erase_bound(&mut ast.generics);
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();

    let generics_filtered: Vec<_> = filter_generics(&ast.generics).collect();
    let generics_filtered = &generics_filtered;
    let erased_generics: Vec<_> = erase_generics(&ast.generics).collect();
    let where_clause = inject_sized_bound(where_clause, &erased_generics);

    let ty = &ast.ident;
    let vis = &ast.vis;
    let erased_ty = syn::Ident::new(&format!("Erased{}", ty), Span::call_site());

    quote! {
        // Parameters are comma separated so that types with more than one
        // type parameter expand to valid syntax; the trailing comma inside
        // the tuple keeps it a tuple for any number of parameters.
        #vis struct #erased_ty<#(#generics_filtered),*> {
            _marker: ::std::marker::PhantomData<(#(#generics_filtered,)*)>,
        }

        unsafe impl #impl_generics eraserhead::Erase for #ty #type_generics #where_clause {
            type Erased = #erased_ty<#(#erased_generics),*>;

            fn erase(&self) -> ::std::ptr::NonNull<Self::Erased> {
                // Reinterprets the reference as a pointer to the erased type.
                // `transmute` only compiles when both types have the same size.
                unsafe { ::std::mem::transmute(self) }
            }

            unsafe fn unerase<'__unerase>(erased: ::std::ptr::NonNull<Self::Erased>)
                -> &'__unerase Self
            {
                // Inverse of `erase`; validity of the pointer and of the
                // chosen lifetime is the caller's responsibility.
                ::std::mem::transmute(erased)
            }
        }
    }
    .into()
}
/// Maps every type parameter `T` of `generics` to the path type `T::Erased`.
///
/// Lifetime parameters are skipped and produce no item.
///
/// # Panics
///
/// Panics (through `todo!`) as soon as a const generic parameter is reached.
fn erase_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::Type> + '_ {
    generics.params.iter().filter_map(|generic| {
        let type_param = match generic {
            syn::GenericParam::Type(type_param) => type_param,
            syn::GenericParam::Lifetime(_) => return None,
            syn::GenericParam::Const(_) => todo!("const generics not supported"),
        };

        // Build the two-segment path `<param>::Erased`.
        let mut segments = syn::punctuated::Punctuated::new();
        segments.push(segment(type_param.ident.clone()));
        segments.push(segment(syn::Ident::new("Erased", Span::call_site())));

        let path = syn::Path {
            leading_colon: None,
            segments,
        };
        Some(syn::Type::Path(syn::TypePath { qself: None, path }))
    })
}
/// Yields a bare copy of every type parameter in `generics`.
///
/// Each yielded parameter keeps its attributes and identifier but has its
/// bounds (with the introducing colon) and its default (with the `=` token)
/// removed. Lifetime and const parameters are not yielded.
fn filter_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::TypeParam> + '_ {
    generics.params.iter().filter_map(|generic| {
        if let syn::GenericParam::Type(type_param) = generic {
            Some(syn::TypeParam {
                attrs: type_param.attrs.clone(),
                ident: type_param.ident.clone(),
                colon_token: None,
                bounds: syn::punctuated::Punctuated::new(),
                eq_token: None,
                default: None,
            })
        } else {
            None
        }
    })
}
/// Appends an `eraserhead::Erase` trait bound to every type parameter of
/// `generics`, in place.
///
/// Lifetime and const parameters are left untouched.
fn inject_erase_bound(generics: &mut syn::Generics) {
    let call_site = Span::call_site();
    let trait_path = ["eraserhead", "Erase"];

    for type_param in generics.type_params_mut() {
        let segments = trait_path
            .iter()
            .map(|name| segment(syn::Ident::new(name, call_site)));
        type_param.bounds.push(bound(segments));
    }
}
/// Returns a where clause that contains the predicates of `where_clause`
/// (if there is one) followed by a `<ty>: Sized` predicate for each entry of
/// `types`.
///
/// When `where_clause` is `None`, a fresh clause with a default `where` token
/// is started, so the result is a clause even if `types` is empty.
fn inject_sized_bound(where_clause: Option<&syn::WhereClause>, types: &[syn::Type])
    -> syn::WhereClause
{
    let mut clause = match where_clause {
        Some(existing) => existing.clone(),
        None => syn::WhereClause {
            where_token: Default::default(),
            predicates: syn::punctuated::Punctuated::new(),
        },
    };

    for bounded_ty in types.iter().cloned() {
        let sized_ident = syn::Ident::new("Sized", Span::call_site());
        let mut bounds = syn::punctuated::Punctuated::new();
        bounds.push(bound(iter::once(segment(sized_ident))));

        let predicate = syn::PredicateType {
            lifetimes: None,
            bounded_ty,
            colon_token: Default::default(),
            bounds,
        };
        clause.predicates.push(syn::WherePredicate::Type(predicate));
    }

    clause
}
/// Wraps `ident` in a path segment that carries no generic arguments.
fn segment(ident: syn::Ident) -> syn::PathSegment {
    // The `From<Ident>` conversion produces a segment with
    // `PathArguments::None`.
    syn::PathSegment::from(ident)
}
/// Builds a plain trait bound from the given path segments.
///
/// The resulting bound has no leading `::`, no `?` modifier, no parentheses
/// and no higher-ranked lifetimes.
fn bound(segments: impl Iterator<Item = syn::PathSegment>) -> syn::TypeParamBound {
    let path = syn::Path {
        leading_colon: None,
        segments: segments.collect(),
    };
    let trait_bound = syn::TraitBound {
        paren_token: None,
        modifier: syn::TraitBoundModifier::None,
        lifetimes: None,
        path,
    };
    syn::TypeParamBound::Trait(trait_bound)
}