1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use std::iter;

use proc_macro2::Span;
use quote::quote;

/// Derives `eraserhead::Erase` for the annotated type and generates its
/// type-erased counterpart.
///
/// For an input type `Foo<'a, T, U>` the derive emits:
///
/// * `struct ErasedFoo<T, U>` with the same visibility as `Foo`. It holds only a
///   `PhantomData` over its type parameters, which are redeclared without
///   bounds or defaults. Lifetime parameters are dropped.
/// * `unsafe impl eraserhead::Erase for Foo<'a, T, U>`, where every type
///   parameter additionally receives an `eraserhead::Erase` bound, each
///   `T::Erased` is required to be `Sized`, and
///   `type Erased = ErasedFoo<T::Erased, U::Erased>`.
///
/// # Panics
///
/// Panics (surfacing as an error at the derive site) if the input has const
/// generic parameters, which are not supported. Input that fails to parse is
/// reported as a regular compile error rather than a panic.
#[proc_macro_derive(Erase)]
pub fn derive_erase(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // `parse_macro_input!` returns a spanned `compile_error!` on failure instead
    // of an opaque "proc-macro derive panicked" message.
    let mut ast = syn::parse_macro_input!(input as syn::DeriveInput);

    // Every type parameter must itself be erasable so that `T::Erased` exists.
    inject_erase_bound(&mut ast.generics);
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
    // Bare type parameters (no bounds/defaults) for declaring the erased struct.
    let generics_filtered: Vec<_> = filter_generics(&ast.generics).collect();
    // `T::Erased` for each type parameter `T`, for instantiating the erased struct.
    let erased_generics: Vec<_> = erase_generics(&ast.generics).collect();

    let where_clause = inject_sized_bound(where_clause, &erased_generics);

    let ty = &ast.ident;
    let vis = &ast.vis;
    let erased_ty = syn::Ident::new(&format!("Erased{}", ty), Span::call_site());

    // Note: repetitions must be comma-separated (`),*`). Otherwise types with
    // several type parameters expand to invalid code such as `<A B>`.
    quote!(
        #vis struct #erased_ty<#(#generics_filtered),*> {
            _marker: ::std::marker::PhantomData<(#(#generics_filtered),*)>,
        }

        unsafe impl #impl_generics eraserhead::Erase for #ty #type_generics #where_clause {
            type Erased = #erased_ty<#(#erased_generics),*>;

            fn erase(&self) -> ::std::ptr::NonNull<Self::Erased> {
                // Changing the pointee type of a thin, non-null pointer is safe;
                // only dereferencing it later (in `unerase`) needs `unsafe`.
                ::std::ptr::NonNull::from(self).cast()
            }

            unsafe fn unerase<'__unerase>(erased: ::std::ptr::NonNull<Self::Erased>)
                -> &'__unerase Self
            {
                // NOTE(review): soundness presumably relies on the trait's
                // contract that `erased` came from `erase` on a still-live
                // `Self`; confirm against the `eraserhead::Erase` docs.
                erased.cast::<Self>().as_ref()
            }
        }
    )
    .into()
}

/// Maps every type parameter `T` in `generics` to the path type `T::Erased`.
///
/// Lifetime parameters are skipped. Const parameters are not supported: the
/// returned iterator panics via `todo!` when it reaches one.
fn erase_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::Type> + '_ {
    generics.params.iter().filter_map(|generic| {
        let type_param = match generic {
            syn::GenericParam::Type(type_param) => type_param,
            syn::GenericParam::Lifetime(_) => return None,
            syn::GenericParam::Const(_) => todo!("const generics not supported"),
        };

        let assoc = syn::Ident::new("Erased", Span::call_site());
        let parts = vec![segment(type_param.ident.clone()), segment(assoc)];
        let path = syn::Path {
            leading_colon: None,
            segments: parts.into_iter().collect(),
        };

        Some(syn::Type::Path(syn::TypePath { qself: None, path }))
    })
}

/// Yields the type parameters of `generics` as bare declarations (`T`).
///
/// Any bounds (`T: Trait`) and defaults (`T = Default`) are removed. Lifetime
/// and const parameters are omitted entirely.
fn filter_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::TypeParam> + '_ {
    generics
        .params
        .iter()
        .filter_map(|generic| {
            if let syn::GenericParam::Type(type_param) = generic {
                Some(type_param)
            } else {
                None
            }
        })
        .map(|type_param| {
            let mut bare = type_param.clone();
            bare.colon_token = None;
            bare.bounds = Default::default();
            bare.eq_token = None;
            bare.default = None;
            bare
        })
}

/// Appends an `eraserhead::Erase` trait bound to every type parameter in
/// `generics`. Lifetime and const parameters are left unchanged.
fn inject_erase_bound(generics: &mut syn::Generics) {
    let erase_path = vec![
        segment(syn::Ident::new("eraserhead", Span::call_site())),
        segment(syn::Ident::new("Erase", Span::call_site())),
    ];
    let erase_bound = bound(erase_path.into_iter());

    for type_param in generics.type_params_mut() {
        type_param.bounds.push(erase_bound.clone());
    }
}

/// Returns a copy of `where_clause`, or a fresh empty clause when it is `None`,
/// with one `ty: Sized` predicate appended for each entry of `types`.
fn inject_sized_bound(where_clause: Option<&syn::WhereClause>, types: &[syn::Type])
    -> syn::WhereClause
{
    let mut clause = match where_clause {
        Some(existing) => existing.clone(),
        None => syn::WhereClause {
            where_token: Default::default(),
            predicates: Default::default(),
        },
    };

    clause.predicates.extend(types.iter().map(|ty| {
        let sized = segment(syn::Ident::new("Sized", Span::call_site()));
        syn::WherePredicate::Type(syn::PredicateType {
            lifetimes: None,
            bounded_ty: ty.clone(),
            colon_token: Default::default(),
            bounds: iter::once(bound(iter::once(sized))).collect(),
        })
    }));

    clause
}

/// Wraps `ident` as a single path segment with no generic arguments.
fn segment(ident: syn::Ident) -> syn::PathSegment {
    syn::PathSegment::from(ident)
}

/// Builds a plain trait bound from `segments`, joined with `::`.
///
/// The bound has no leading `::`, no parentheses, no `?` modifier and no
/// `for<'a>` binder.
fn bound(segments: impl Iterator<Item = syn::PathSegment>) -> syn::TypeParamBound {
    let path = syn::Path {
        leading_colon: None,
        segments: segments.collect(),
    };
    let trait_bound = syn::TraitBound {
        paren_token: None,
        modifier: syn::TraitBoundModifier::None,
        lifetimes: None,
        path,
    };
    syn::TypeParamBound::Trait(trait_bound)
}