hacky-types 0.1.0

A crate containing workarounds for the Rust type system
Documentation
extern crate proc_macro;

use proc_macro2::{Ident, Span};
use quote::quote;
use std::fmt::Display;
use syn::{
    parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
    AttributeArgs, FnArg, Generics, ItemFn, Lit, Meta, NestedMeta, Pat, Type, TypeParam,
};

pub(crate) fn main(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let sa = parse_macro_input!(attr as AttributeArgs);
    let si = parse_macro_input!(item as ItemFn);
    let name = &si.sig.ident;
    let mut trait_name = Ident::new(&format!("IntoExtensionMethod{}", &name), name.span());
    let mut impl_type_param = Ident::new("FirstArgImplType Param", Span::call_site());
    let mut remove_fn = false;

    for item in sa {
        if let NestedMeta::Meta(Meta::NameValue(x)) = item {
            if x.path.is_ident("trait_name") {
                if let Lit::Str(lstr) = x.lit {
                    trait_name = Ident::new(&lstr.value(), lstr.span());
                } else {
                    return syn_err(
                        x.lit.span(),
                        "expected string literal (eg: `\"IntoMyExtensionMethod\"`)",
                    );
                }
            } else if x.path.is_ident("impl_type_param") {
                if let Lit::Str(lstr) = x.lit {
                    impl_type_param = Ident::new(&lstr.value(), lstr.span());
                } else {
                    return syn_err(
                        x.lit.span(),
                        "expected string literal (eg: `\"FirstArgImplTypeParam\"`)",
                    );
                }
            } else if x.path.is_ident("remove_fn") {
                if let Lit::Bool(lbool) = x.lit {
                    remove_fn = lbool.value;
                } else {
                    return syn_err(x.lit.span(), "expected bool literal (eg: `true`)");
                }
            }
        }
    }

    let vis_trait = &si.vis;
    let mut generics = si.sig.generics.clone();
    let kw_async = si.sig.asyncness;
    let kw_const = si.sig.constness;
    let kw_unsafe = si.sig.unsafety;

    let mut first_arg = if let Some(arg) = si.sig.inputs.first() {
        arg.clone()
    } else {
        return syn_err(
            si.sig.paren_token.span,
            "extension method must have at least 1 input",
        );
    };
    let mut first_arg_pt = if let FnArg::Typed(arg) = &first_arg {
        arg
    } else {
        return syn_err(first_arg.span(), "first argument must not be `self`");
    };

    if let Type::Path(p) = first_arg_pt.ty.as_ref() {
        if p.path.is_ident("Self") {
            return syn_err(p.span(), "first argument cannot be of type `Self`");
        }
    }

    let (fa_mut, fa_ident) = if let Pat::Ident(ident) = first_arg_pt.pat.as_ref() {
        (&ident.mutability, &ident.ident)
    } else {
        return syn_err(first_arg_pt.pat.span(), "expected identifier");
    };

    if let Type::ImplTrait(p) = first_arg_pt.ty.as_ref() {
        let fa_impl_bounds = &p.bounds;
        generics
            .params
            .push(parse_quote!(#impl_type_param: #fa_impl_bounds));

        // Continue as typeParams
        let attrs = &first_arg_pt.attrs;
        first_arg = parse_quote!(#(#attrs)* #fa_mut #fa_ident: #impl_type_param);
        first_arg_pt = if let FnArg::Typed(pt) = &first_arg {
            pt
        } else {
            panic!("Bro how")
        }
    }

    let trait_tp = if let Some(p) = is_type_param(&first_arg_pt.ty, &generics) {
        Some((p, &None))
    } else if let Type::Ptr(ptr) = first_arg_pt.ty.as_ref() {
        is_type_param(&ptr.elem, &generics).map(|x| (x, &None))
    } else if let Type::Reference(rf) = first_arg_pt.ty.as_ref() {
        is_type_param(&rf.elem, &generics).map(|x| (x, &rf.lifetime))
    } else {
        None
    };

    let trait_generics: Option<Generics> = trait_tp.map(|(tp, lt)| {
        if let Some(lt) = lt {
            parse_quote!(<#lt, #tp>)
        } else {
            parse_quote!(<#tp>)
        }
    });

    let remain_generics = if let Some((tp, lto)) = trait_tp {
        let p: Punctuated<_, Comma> = generics
            .params
            .iter()
            .filter(|x| match x {
                syn::GenericParam::Type(t) => t.ident != tp.ident,
                syn::GenericParam::Lifetime(l) => {
                    if let Some(lt) = lto {
                        l.lifetime.ident != lt.ident
                    } else {
                        false
                    }
                }
                _ => true,
            })
            .collect();

        parse_quote!(<#p>)
    } else {
        generics.clone()
    };

    let trait_doc =
        format!("this trait is autogenerated for extension method `obj.{name}(...)` to work");
    let fn_attrs = &si.attrs;
    let first_arg_attrs = &first_arg_pt.attrs;
    let fn_out = &si.sig.output;
    let wcl = &generics.where_clause;
    let inp: Punctuated<_, Comma> = si.sig.inputs.iter().skip(1).collect();

    let trait_code = quote! {
        #[doc = #trait_doc]
        #[allow(non_camel_case_types)]
        #vis_trait trait #trait_name #trait_generics {
            #(#fn_attrs)*
            #kw_const #kw_unsafe #kw_async fn #name #remain_generics(#(#first_arg_attrs)* self, #inp) #fn_out #wcl;
        }
    };

    let impl_ty = first_arg_pt.ty.as_ref();
    let stmts = &si.block.stmts;
    let fa_pat = &first_arg_pt.pat;

    let impl_code = quote! {
        impl #trait_generics #trait_name #trait_generics for #impl_ty {
            #(#fn_attrs)*
            #kw_const #kw_unsafe #kw_async fn #name #remain_generics(#(#first_arg_attrs)* self, #inp) #fn_out #wcl {
                let #fa_pat = self;
                #(#stmts)*
            }
        }
    };

    let original_fn = match remove_fn {
        true => None,
        false => Some(&si),
    };

    quote! {
        #trait_code

        #impl_code

        #original_fn
    }
    .into()
}

/// Looks up the type parameter that `ty` names.
///
/// Returns the matching entry of `tps` when `ty` is a bare path equal to one of the declared
/// type parameters (e.g. `T` for `<T: Clone>`), and `None` for every other kind of type.
fn is_type_param<'a>(ty: &Type, tps: &'a Generics) -> Option<&'a TypeParam> {
    match ty {
        Type::Path(type_path) => tps
            .type_params()
            .find(|param| type_path.path.is_ident(&param.ident)),
        _ => None,
    }
}

fn syn_err<T: Display>(span: proc_macro2::Span, message: T) -> proc_macro::TokenStream {
    syn::Error::new(span, message).into_compile_error().into()
}