enum_delegate_lib 0.2.0

Internal macro implementations for enum_delegate - use to implement your own macros
Documentation
//! Code generation for implementing a trait for an enum

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::spanned::Spanned;
use syn::{FnArg, ItemTrait, Signature, TraitItem, TraitItemMethod, TraitItemType};

use crate::error::InvalidInput;
use crate::generate_delegation::MacroContext;
use crate::input::AssociatedTypeUnification;

/// Build the `impl Trait for Enum` block that forwards every trait item to the enum's variants.
///
/// The trait is validated first, then each trait item is translated in declaration order:
/// methods become delegating functions and associated types become type declarations.
///
/// # Errors
///
/// Returns the first error encountered, which is one of:
/// - the error reported by `validate_trait`,
/// - `InvalidInput::UnsupportedTraitItem` for any item that is neither a method nor an
///   associated type,
/// - whatever `generate_fn_implementation` reports for a method.
pub(crate) fn generate_trait_impl(context: &MacroContext) -> Result<TokenStream, InvalidInput> {
    let trait_path = context.trait_path;
    let enum_path = context.enum_path;
    let parsed_trait = context.parsed_trait;

    validate_trait(parsed_trait)?;

    let mut members = Vec::with_capacity(parsed_trait.items.len());
    for item in parsed_trait.items.iter() {
        let generated = match item {
            TraitItem::Method(method) => generate_fn_implementation(context, method)?,
            TraitItem::Type(type_) => generate_type_implementation(context, type_),
            other => return Err(InvalidInput::UnsupportedTraitItem(other.span())),
        };
        members.push(generated);
    }

    // May be empty; the `where` keyword is emitted regardless.
    let where_predicates = generate_associated_type_equality_guards(context);

    Ok(quote! {
        impl #trait_path for #enum_path where #where_predicates {
            #(#members)*
        }
    })
}

/// Reject trait declarations that use features the delegation macro does not handle.
///
/// # Errors
///
/// Returns `InvalidInput::UnsupportedFeature` when the trait declares generic parameters
/// or supertraits. Generic parameters are checked first, so a trait with both is reported
/// as using `"generics"`.
fn validate_trait(trait_: &ItemTrait) -> Result<(), InvalidInput> {
    let generics = &trait_.generics;
    let supertraits = &trait_.supertraits;

    match (generics.params.is_empty(), supertraits.is_empty()) {
        (false, _) => Err(InvalidInput::UnsupportedFeature(
            generics.span(),
            "generics",
        )),
        (true, false) => Err(InvalidInput::UnsupportedFeature(
            supertraits.span(),
            "supertraits",
        )),
        (true, true) => Ok(()),
    }
}

/// Produce the `where`-clause predicates that force `Same`-unified associated types to agree.
///
/// For every associated type configured as `AssociatedTypeUnification::Same`, one predicate
/// is emitted per variant after the first, pairing the first variant's associated type with
/// that variant's associated type and bounding the pair by the helper module's `EqualTypes`.
/// Each predicate is followed by a comma.
///
/// Returns an empty token stream when the enum has fewer than two variants, since there is
/// nothing to compare in that case.
fn generate_associated_type_equality_guards(context: &MacroContext) -> TokenStream {
    let all_variants = context.parsed_enum.variants();
    if all_variants.len() < 2 {
        return TokenStream::new();
    }

    let trait_name = context.trait_path;
    let helper_mod = &context.helper_mod_name;

    // Cannot fail: the length check above guarantees at least two variants.
    let first_variant_type = &all_variants.iter().next().unwrap().type_;

    let mut guards = Vec::new();
    for (type_ident, unification) in context.associated_type_config.iter() {
        if !matches!(unification, AssociatedTypeUnification::Same) {
            continue;
        }

        // Compare every later variant against the first one.
        for variant in all_variants.iter().skip(1) {
            let other_variant_type = &variant.type_;

            guards.push(quote! {
                (
                    <#first_variant_type as #trait_name>::#type_ident,
                    <#other_variant_type as #trait_name>::#type_ident
                ): #helper_mod::EqualTypes
            });
        }
    }

    quote! {
        #(#guards,)*
    }
}

/// Emit the `type X = ...;` declaration for one associated type of the trait.
///
/// The right-hand side depends on how the associated type is configured:
/// - `Same`: the declaration produced by `redeclare_first_variant_associated_type`;
/// - `EnumWrap`: the entry registered for this type in `associated_type_enum_names`.
///
/// # Panics
///
/// Panics if the associated type has no entry in `associated_type_config`, or if an
/// `EnumWrap` type has no entry in `associated_type_enum_names`.
fn generate_type_implementation(context: &MacroContext, type_: &TraitItemType) -> TokenStream {
    let type_ident = &type_.ident;

    let unification = context
        .associated_type_config
        .get(type_ident)
        .expect("All associated types should have a parsed type configuration.");

    let wrapper_enum = match unification {
        AssociatedTypeUnification::Same => {
            return redeclare_first_variant_associated_type(context, type_);
        }
        AssociatedTypeUnification::EnumWrap => context
            .associated_type_enum_names
            .get(type_ident)
            .expect("All EnumWrap types should have an entry in associated_type_enum_names"),
    };

    quote! {
        type #type_ident = #wrapper_enum;
    }
}

/// Declare an associated type as an alias of the first variant's associated type.
///
/// Emits `type X = <FirstVariantType as Trait>::X;`, where `X` is the identifier of
/// `type_` and `Trait` is the context's trait path.
fn redeclare_first_variant_associated_type(
    context: &MacroContext,
    type_: &TraitItemType,
) -> TokenStream {
    let first_variant = context.parsed_enum.first_variant();
    let source_type = &first_variant.type_;
    let via_trait = context.trait_path;
    let name = &type_.ident;

    quote! {
        type #name = <#source_type as #via_trait>::#name;
    }
}

/// Generate one delegating method for the enum's `impl` block.
///
/// The emitted function reuses the trait method's signature unchanged. Its body matches on
/// `self` and, for every enum variant, calls the trait method on the wrapped value. Each
/// argument after the receiver is passed through `try_into()` (panicking in the generated
/// code with "Received argument of incorrect type" on failure), and the call result is
/// converted with `into()`.
///
/// # Errors
///
/// Returns `InvalidInput::NoReceiverArgument` when the method's first parameter is not a
/// receiver (or when it has no parameters at all).
fn generate_fn_implementation(
    context: &MacroContext,
    method: &TraitItemMethod,
) -> Result<TokenStream, InvalidInput> {
    let signature = &method.sig;
    let method_name = &signature.ident;
    let trait_path = context.trait_path;
    let enum_path = context.enum_path;

    // Delegation needs a `self` to match on, so methods without a receiver are rejected.
    match signature.inputs.first() {
        Some(FnArg::Receiver(_)) => {}
        _ => return Err(InvalidInput::NoReceiverArgument(signature.span())),
    }

    // Everything after the receiver is forwarded to the delegated call.
    let mut forwarded_args = Vec::new();
    for input in signature.inputs.iter().skip(1) {
        match input {
            FnArg::Receiver(_) => {
                unreachable!("only first argument, which we skipped, can be receiver")
            }
            FnArg::Typed(typed) => {
                let pattern = &typed.pat;
                forwarded_args.push(quote!(
                    #pattern.try_into().expect("Received argument of incorrect type")
                ));
            }
        }
    }

    // Name bound to the wrapped value in every match arm.
    let target = format_ident!("target");

    let match_arms: Vec<TokenStream> = context
        .parsed_enum
        .variants()
        .iter()
        .map(|variant| {
            let variant_name = &variant.name;

            quote! {
                #enum_path::#variant_name(#target) => {
                    #trait_path::#method_name( #target, #(#forwarded_args,)* ).into()
                }
            }
        })
        .collect();

    Ok(quote! {
        #signature {
            match self {
                #(#match_arms)*
            }
        }
    })
}