higher-derive 0.1.1

Custom derives for `higher`
Documentation
#![recursion_limit = "256"]

//! Custom derives for the [`higher`][higher] and [`higher-cat`][higher-cat]
//! crates.
//!
//! Please see the relevant crates for documentation on each derive.
//!
//! [higher]: https://docs.rs/crate/higher
//! [higher-cat]: https://docs.rs/crate/higher-cat

extern crate proc_macro;

use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
    parse_macro_input, parse_quote, punctuated::Punctuated, Data, DataEnum, DeriveInput, Field,
    Fields, FieldsNamed, FieldsUnnamed, GenericParam, Ident, Index, Type, TypeParam,
};

fn replace<A: PartialEq, P>(list: &mut Punctuated<A, P>, target: &A, value: A) {
    for t in list {
        if target == t {
            *t = value;
            return;
        }
    }
    panic!("failed to substitute");
}

#[proc_macro_derive(Lift)]
pub fn derive_lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;

    let generic_type = input
        .generics
        .type_params()
        .next()
        .expect("can't Lift a type without type parameters");
    let generic_type: GenericParam = generic_type.clone().into();

    let mut impl_vars_2 = input.generics.params.clone();
    impl_vars_2.push(parse_quote!(LiftTarget1));

    let mut impl_vars_3 = impl_vars_2.clone();
    impl_vars_3.push(parse_quote!(LiftTarget2));

    let mut target_vars_2 = input.generics.params.clone();
    replace(&mut target_vars_2, &generic_type, parse_quote!(LiftTarget1));

    let mut target_vars_3 = input.generics.params.clone();
    replace(&mut target_vars_3, &generic_type, parse_quote!(LiftTarget2));

    let (_, type_generics, where_clause) = input.generics.split_for_impl();

    let out = quote! {
        impl<#impl_vars_2> ::higher::Lift<#generic_type, LiftTarget1>
        for #name #type_generics #where_clause {
            type Source = Self;
            type Target1 = #name<#target_vars_2>;
        }

        impl<#impl_vars_3> ::higher::Lift3<#generic_type, LiftTarget2, LiftTarget1>
        for #name #type_generics #where_clause {
            type Target2 = #name<#target_vars_3>;
        }
    };
    proc_macro::TokenStream::from(out)
}

#[proc_macro_derive(Bilift)]
pub fn derive_bilift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;

    let mut types = input.generics.type_params();
    let left_type = types
        .next()
        .expect("can't Bilift a type without type parameters");
    let left_type: GenericParam = left_type.clone().into();
    let right_type = types
        .next()
        .expect("can't Bilift a type with less than two type parameters");
    let right_type: GenericParam = right_type.clone().into();

    let mut impl_vars = input.generics.params.clone();
    impl_vars.push(parse_quote!(LiftTarget1));
    impl_vars.push(parse_quote!(LiftTarget2));

    let mut target_vars = input.generics.params.clone();
    replace(&mut target_vars, &left_type, parse_quote!(LiftTarget1));
    replace(&mut target_vars, &right_type, parse_quote!(LiftTarget2));

    let (_, type_generics, where_clause) = input.generics.split_for_impl();

    let out = quote! {
        impl<#impl_vars> ::higher::Bilift<#left_type, #right_type, LiftTarget1, LiftTarget2>
        for #name #type_generics #where_clause {
            type Source = Self;
            type Target = #name<#target_vars>;
        }
    };
    proc_macro::TokenStream::from(out)
}

fn match_type_param(param: &TypeParam, ty: &Type) -> bool {
    if let Type::Path(path) = ty {
        if let Some(segment) = path.path.segments.iter().next() {
            if segment.ident == param.ident {
                return true;
            }
        }
    }
    false
}

fn filter_fields<'a, P, F1, F2>(
    fields: &'a Punctuated<Field, P>,
    ty: &TypeParam,
    transform: F1,
    copy: F2,
) -> Vec<TokenStream>
where
    F1: Fn(&Ident) -> TokenStream,
    F2: Fn(&Ident) -> TokenStream,
{
    fields
        .iter()
        .map(|field| {
            if match_type_param(ty, &field.ty) {
                transform(&field.ident.clone().unwrap())
            } else {
                copy(&field.ident.clone().unwrap())
            }
        })
        .collect()
}

#[proc_macro_derive(Functor)]
pub fn derive_functor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    let type_params = &input.generics.params;
    let where_clause = &input.generics.where_clause;
    let generic_type = input
        .generics
        .type_params()
        .next()
        .expect("can't derive Functor for a type without type parameters");
    let map_impl = match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => derive_functor_named_struct(name, fields, &generic_type),
            Fields::Unnamed(fields) => derive_functor_unnamed_struct(name, fields, &generic_type),
            Fields::Unit => panic!("can't derive Functor for an empty struct"),
        },
        Data::Enum(data) => derive_functor_enum(name, data, &generic_type),
        Data::Union(_) => panic!("can't derive Functor for a union type"),
    };
    quote!(
        impl<#type_params, TargetType> ::higher_cat::Functor<#generic_type, TargetType> for #name<#type_params> #where_clause {
            fn map<F>(self, f: F) -> <Self as ::higher::Lift<#generic_type, TargetType>>::Target1
            where
                F: Fn(#generic_type) -> TargetType
            {
                #map_impl
            }
        }
    ).into()
}

fn derive_functor_named_struct(
    name: &Ident,
    fields: &FieldsNamed,
    generic_type: &TypeParam,
) -> TokenStream {
    let apply_fields = filter_fields(
        &fields.named,
        generic_type,
        |field| {
            quote! {
                #field: f(self.#field),
            }
        },
        |field| {
            quote! {
                #field: self.#field,
            }
        },
    )
    .into_iter();
    quote! {
        #name {
            #(#apply_fields)*
        }
    }
}

fn derive_functor_unnamed_struct(
    name: &Ident,
    fields: &FieldsUnnamed,
    generic_type: &TypeParam,
) -> TokenStream {
    let fields = fields.unnamed.iter().enumerate().map(|(index, field)| {
        let index = Index::from(index);
        if match_type_param(generic_type, &field.ty) {
            quote! { f(self.#index), }
        } else {
            quote! { self.#index, }
        }
    });
    quote! { #name(#(#fields)*) }
}

fn derive_functor_enum(name: &Ident, data: &DataEnum, generic_type: &TypeParam) -> TokenStream {
    let variants = data.variants.iter().map(|variant| {
        let ident = &variant.ident;
        match &variant.fields {
            Fields::Named(fields) => {
                let args: Vec<Ident> = fields
                    .named
                    .iter()
                    .map(|field| {
                        Ident::new(
                            &format!("arg_{}", field.ident.clone().unwrap()),
                            field.ident.clone().unwrap().span(),
                        )
                    })
                    .collect();
                let apply =
                    fields
                        .named
                        .iter()
                        .zip(args.clone().into_iter())
                        .map(|(field, arg)| {
                            let name = &field.ident;
                            if match_type_param(generic_type, &field.ty) {
                                quote! { #name: f(#arg) }
                            } else {
                                quote! { #name: #arg }
                            }
                        });
                let args = fields
                    .named
                    .iter()
                    .zip(args.into_iter())
                    .map(|(field, arg)| {
                        let name = &field.ident;
                        quote! { #name:#arg }
                    });
                quote! {
                    #name::#ident { #(#args,)* } => #name::#ident { #(#apply,)* },
                }
            }
            Fields::Unnamed(fields) => {
                let args: Vec<Ident> = fields
                    .unnamed
                    .iter()
                    .enumerate()
                    .map(|(index, _)| Ident::new(&format!("arg{}", index), Span::call_site()))
                    .collect();
                let fields = fields.unnamed.iter().zip(args.iter()).map(|(field, arg)| {
                    if match_type_param(generic_type, &field.ty) {
                        quote! { f(#arg) }
                    } else {
                        quote! { #arg }
                    }
                });
                let args = args.iter();
                quote! {
                    #name::#ident(#(#args,)*) => #name::#ident(#(#fields,)*),
                }
            }
            Fields::Unit => quote! {
                #name::#ident => #name::#ident,
            },
        }
    });
    quote! {
        match self {
            #(#variants)*
        }
    }
}