eraserhead-derive 1.0.0

Derive macro for the `Erase` trait from the `eraserhead` crate.
Documentation
use std::iter;

use proc_macro2::Span;
use quote::quote;

/// Derives `eraserhead::Erase` for the annotated type.
///
/// For an input type `Foo<'a, T, U>` this emits:
///
/// * a marker struct `ErasedFoo<T, U>` with the same visibility as `Foo`, whose only field is a
///   `PhantomData` over its type parameters (bounds and defaults stripped, lifetimes and const
///   parameters dropped);
/// * `unsafe impl eraserhead::Erase for Foo<'a, T, U>`, where every type parameter gains an
///   `eraserhead::Erase` bound, `type Erased` is `ErasedFoo` applied to each type parameter
///   mapped through `erase_generics`, and each of those erased argument types gains a `Sized`
///   bound.
///
/// `erase` and `unerase` are plain pointer casts between `&Self` and `NonNull<Self::Erased>`.
///
/// # Panics
///
/// Panics (reported by the compiler as an error at the derive site) if the type has const
/// generic parameters, because `erase_generics` does not support them.
#[proc_macro_derive(Erase)]
pub fn derive_erase(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // Reports malformed input as a compile error instead of panicking.
    let mut ast = syn::parse_macro_input!(input as syn::DeriveInput);

    inject_erase_bound(&mut ast.generics);
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
    // Bare type parameters (no bounds/defaults) used to declare the marker struct.
    let generics_filtered: Vec<_> = filter_generics(&ast.generics).collect();
    let generics_filtered = &generics_filtered;
    let erased_generics: Vec<_> = erase_generics(&ast.generics).collect();

    // The marker struct's parameters are implicitly `Sized`, so every erased argument
    // substituted into it must be `Sized` as well.
    let where_clause = inject_sized_bound(where_clause, &erased_generics);

    let ty = &ast.ident;
    let vis = &ast.vis;
    let erased_ty = syn::Ident::new(&format!("Erased{}", ty), Span::call_site());

    quote!(
        #vis struct #erased_ty<#(#generics_filtered),*> {
            // A trailing-comma tuple so zero, one, or many parameters all form a valid type.
            _marker: ::std::marker::PhantomData<(#(#generics_filtered,)*)>,
        }

        unsafe impl #impl_generics eraserhead::Erase for #ty #type_generics #where_clause {
            type Erased = #erased_ty<#(#erased_generics),*>;

            fn erase(&self) -> ::std::ptr::NonNull<Self::Erased> {
                // A reference is never null, so this cast needs no `unsafe`.
                ::std::ptr::NonNull::from(self).cast()
            }

            unsafe fn unerase<'__unerase>(erased: ::std::ptr::NonNull<Self::Erased>)
                -> &'__unerase Self
            {
                // SAFETY: the caller upholds `Erase::unerase`'s contract that `erased` came
                // from `erase` on a `Self` that is still alive for `'__unerase`.
                unsafe { erased.cast::<Self>().as_ref() }
            }
        }
    ).into()
}

fn erase_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::Type> + '_ {
    generics.params.iter().filter_map(|param| match param {
        syn::GenericParam::Lifetime(_)  => None,
        syn::GenericParam::Const(_)     => todo!("const generics not supported"),
        syn::GenericParam::Type(param)  => {
            let seg1 = segment(param.ident.clone());
            let seg2 = segment(syn::Ident::new("Erased", Span::call_site()));
            Some(syn::Type::Path(syn::TypePath {
                qself: None,
                path: syn::Path {
                    leading_colon:  None,
                    segments: iter::once(seg1).chain(iter::once(seg2)).collect(),
                },
            }))
        }
    })
}

/// Yields a copy of every type parameter in `generics`, in declaration order,
/// with its `: Bounds` and `= Default` parts removed. The attributes and the
/// identifier are kept. Lifetime and const parameters are skipped.
fn filter_generics(generics: &syn::Generics) -> impl Iterator<Item = syn::TypeParam> + '_ {
    generics.type_params().map(|type_param| syn::TypeParam {
        colon_token: None,
        bounds: syn::punctuated::Punctuated::new(),
        eq_token: None,
        default: None,
        ..type_param.clone()
    })
}

/// Appends an `eraserhead::Erase` trait bound to every type parameter of
/// `generics`, leaving lifetime and const parameters untouched.
fn inject_erase_bound(generics: &mut syn::Generics) {
    generics.type_params_mut().for_each(|type_param| {
        let erase_path = ["eraserhead", "Erase"]
            .iter()
            .map(|name| segment(syn::Ident::new(name, Span::call_site())));
        type_param.bounds.push(bound(erase_path));
    });
}

/// Returns a copy of `where_clause`, extended with one `Ty: Sized` predicate
/// per entry of `types`, in order.
///
/// When `where_clause` is `None`, a fresh `where` clause with no predicates is
/// used as the starting point. Any existing predicates are kept ahead of the
/// new ones.
fn inject_sized_bound(where_clause: Option<&syn::WhereClause>, types: &[syn::Type])
    -> syn::WhereClause
{
    let mut clause = match where_clause {
        Some(existing) => existing.clone(),
        None => syn::WhereClause {
            where_token: Default::default(),
            predicates: syn::punctuated::Punctuated::new(),
        },
    };

    let sized_predicates = types.iter().map(|bounded_ty| {
        let sized_path = iter::once(segment(syn::Ident::new("Sized", Span::call_site())));
        let mut bounds = syn::punctuated::Punctuated::new();
        bounds.push(bound(sized_path));
        syn::WherePredicate::Type(syn::PredicateType {
            lifetimes: None,
            bounded_ty: bounded_ty.clone(),
            colon_token: Default::default(),
            bounds,
        })
    });
    clause.predicates.extend(sized_predicates);

    clause
}

/// Builds a single path segment from `ident` with no generic arguments
/// (for example, the `Erase` in `eraserhead::Erase`).
fn segment(ident: syn::Ident) -> syn::PathSegment {
    syn::PathSegment::from(ident)
}

/// Turns a sequence of path segments into a plain trait bound such as
/// `eraserhead::Erase` or `Sized`.
///
/// The bound has no leading `::`, no surrounding parentheses, no `?`
/// modifier and no `for<...>` lifetime binder.
fn bound(segments: impl Iterator<Item = syn::PathSegment>) -> syn::TypeParamBound {
    let path = syn::Path {
        leading_colon: None,
        segments: segments.collect(),
    };
    let trait_bound = syn::TraitBound {
        paren_token: None,
        modifier: syn::TraitBoundModifier::None,
        lifetimes: None,
        path,
    };
    syn::TypeParamBound::Trait(trait_bound)
}