1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
use std::iter; use proc_macro2::Span; use quote::quote; #[proc_macro_derive(Erase)] pub fn derive_erase(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut ast: syn::DeriveInput = syn::parse(input).unwrap(); inject_erase_bound(&mut ast.generics); let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); let generics_filtered: Vec<_> = filter_generics(&ast.generics).collect(); let generics_filtered = &generics_filtered; let erased_generics: Vec<_> = erase_generics(&ast.generics).collect(); let where_clause = inject_sized_bound(where_clause, &erased_generics[..]); let ty = &ast.ident; let vis = &ast.vis; let erased_ty = syn::Ident::new(&format!("Erased{}", &ty), Span::call_site()); quote!( #vis struct #erased_ty<#(#generics_filtered)*> { _marker: ::std::marker::PhantomData<(#(#generics_filtered)*)>, } unsafe impl #impl_generics eraserhead::Erase for #ty #type_generics #where_clause { type Erased = #erased_ty<#(#erased_generics),*>; fn erase(&self) -> std::ptr::NonNull<Self::Erased> { unsafe { std::mem::transmute(self) } } unsafe fn unerase<'__unerase>(erased: std::ptr::NonNull<Self::Erased>) -> &'__unerase Self { std::mem::transmute(erased) } } ).into() } fn erase_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::Type> + '_ { generics.params.iter().filter_map(|param| match param { syn::GenericParam::Lifetime(_) => None, syn::GenericParam::Const(_) => todo!("const generics not supported"), syn::GenericParam::Type(param) => { let seg1 = segment(param.ident.clone()); let seg2 = segment(syn::Ident::new("Erased", Span::call_site())); Some(syn::Type::Path(syn::TypePath { qself: None, path: syn::Path { leading_colon: None, segments: iter::once(seg1).chain(iter::once(seg2)).collect(), }, })) } }) } fn filter_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::TypeParam> + '_ { generics.params.iter().cloned().filter_map(|param| match param { syn::GenericParam::Type(mut param) => { param.colon_token = None; param.bounds = iter::empty::<syn::TypeParamBound>().collect(); param.eq_token = None; param.default = None; Some(param) } _ => None, }) } fn inject_erase_bound(generics: &mut syn::Generics) { for param in generics.params.iter_mut() { if let syn::GenericParam::Type(param) = param { let eraserhead = segment(syn::Ident::new("eraserhead", Span::call_site())); let erase = segment(syn::Ident::new("Erase", Span::call_site())); param.bounds.push(bound(iter::once(eraserhead).chain(iter::once(erase)))); } } } fn inject_sized_bound(where_clause: Option<&syn::WhereClause>, types: &[syn::Type]) -> syn::WhereClause { let mut where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause { where_token: syn::token::Where::default(), predicates: iter::empty::<syn::WherePredicate>().collect(), }); for ty in types { let sized = bound(iter::once(segment(syn::Ident::new("Sized", Span::call_site())))); where_clause.predicates.push(syn::WherePredicate::Type(syn::PredicateType { lifetimes: None, bounded_ty: ty.clone(), colon_token: syn::token::Colon::default(), bounds: iter::once(sized).collect(), })); } where_clause } fn segment(ident: syn::Ident) -> syn::PathSegment { syn::PathSegment { ident, arguments: syn::PathArguments::None, } } fn bound(segments: impl Iterator<Item = syn::PathSegment>) -> syn::TypeParamBound { let segments = segments.collect(); syn::TypeParamBound::Trait(syn::TraitBound { paren_token: None, modifier: syn::TraitBoundModifier::None, lifetimes: None, path: syn::Path { leading_colon: None, segments, } }) }