use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::spanned::Spanned;
use syn::{FnArg, ItemTrait, Signature, TraitItem, TraitItemMethod, TraitItemType};
use crate::error::InvalidInput;
use crate::generate_delegation::MacroContext;
use crate::input::AssociatedTypeUnification;
/// Generates `impl Trait for Enum where ... { ... }`, delegating every trait item
/// to the enum's variants.
///
/// The trait is first checked with `validate_trait`. Each trait item is then
/// translated in declaration order:
/// - methods through `generate_fn_implementation`,
/// - associated types through `generate_type_implementation`.
///
/// Any other kind of trait item is rejected. The `where` clause carries the
/// predicates produced by `generate_associated_type_equality_guards`, which may
/// be empty.
///
/// # Errors
/// Returns the first error encountered:
/// - the error from `validate_trait`;
/// - an error from `generate_fn_implementation`;
/// - `InvalidInput::UnsupportedTraitItem` for an item that is neither a method
///   nor an associated type.
pub(crate) fn generate_trait_impl(context: &MacroContext) -> Result<TokenStream, InvalidInput> {
    let parsed_trait = context.parsed_trait;
    validate_trait(parsed_trait)?;

    let mut delegated_items = Vec::with_capacity(parsed_trait.items.len());
    for trait_item in &parsed_trait.items {
        let item_tokens = match trait_item {
            TraitItem::Method(method) => generate_fn_implementation(context, method)?,
            TraitItem::Type(assoc_type) => generate_type_implementation(context, assoc_type),
            unsupported => return Err(InvalidInput::UnsupportedTraitItem(unsupported.span())),
        };
        delegated_items.push(item_tokens);
    }

    let where_predicates = generate_associated_type_equality_guards(context);
    let trait_path = context.trait_path;
    let enum_path = context.enum_path;
    Ok(quote! {
        impl #trait_path for #enum_path where #where_predicates {
            #(#delegated_items)*
        }
    })
}
/// Rejects trait shapes that the delegation code generator cannot handle.
///
/// # Errors
/// Returns `InvalidInput::UnsupportedFeature`, checking in this order, if the trait:
/// - declares generic parameters;
/// - has supertraits;
/// - is an `unsafe trait`. The generator only ever emits a plain `impl`, which
///   the compiler rejects for unsafe traits (E0200). Silently emitting
///   `unsafe impl` instead would assert safety invariants the macro cannot
///   verify, so the trait is reported as unsupported.
fn validate_trait(trait_: &ItemTrait) -> Result<(), InvalidInput> {
    if !trait_.generics.params.is_empty() {
        return Err(InvalidInput::UnsupportedFeature(
            trait_.generics.span(),
            "generics",
        ));
    }
    if !trait_.supertraits.is_empty() {
        return Err(InvalidInput::UnsupportedFeature(
            trait_.supertraits.span(),
            "supertraits",
        ));
    }
    if let Some(unsafety) = &trait_.unsafety {
        return Err(InvalidInput::UnsupportedFeature(
            unsafety.span(),
            "unsafe traits",
        ));
    }
    Ok(())
}
/// Builds the `where`-clause predicates that force all variants to agree on every
/// associated type configured as `AssociatedTypeUnification::Same`.
///
/// For each such associated type `T`, one predicate is emitted for every variant
/// after the first:
/// `(<First as Trait>::T, <Other as Trait>::T): helper_mod::EqualTypes`.
/// Each predicate is followed by a trailing comma. Enums with fewer than two
/// variants need no guards, so an empty token stream is returned for them.
fn generate_associated_type_equality_guards(context: &MacroContext) -> TokenStream {
    let all_variants = context.parsed_enum.variants();
    if all_variants.len() < 2 {
        return TokenStream::new();
    }

    let trait_path = context.trait_path;
    let equal_types_mod = &context.helper_mod_name;

    let predicates = context
        .associated_type_config
        .iter()
        .filter_map(|(type_ident, unification)| match unification {
            AssociatedTypeUnification::Same => Some(type_ident),
            AssociatedTypeUnification::EnumWrap => None,
        })
        .flat_map(|type_ident| {
            let mut variant_iter = all_variants.iter();
            // Cannot fail: we returned early above unless there are at least two variants.
            let reference_type = &variant_iter.next().unwrap().type_;
            variant_iter.map(move |other_variant| {
                let other_type = &other_variant.type_;
                quote! {
                    (
                        <#reference_type as #trait_path>::#type_ident,
                        <#other_type as #trait_path>::#type_ident
                    ): #equal_types_mod::EqualTypes
                }
            })
        });

    quote! {
        #(#predicates,)*
    }
}
/// Emits the `type Name = ...;` item for one associated type of the delegated trait.
///
/// The chosen type depends on the associated type's configured
/// `AssociatedTypeUnification`:
/// - `EnumWrap`: the type is set to the wrapper enum recorded for it in
///   `associated_type_enum_names`.
/// - `Same`: the first variant's associated type is reused, via
///   `redeclare_first_variant_associated_type`.
///
/// # Panics
/// Panics if `type_` has no entry in `associated_type_config`. Also panics if it
/// is configured as `EnumWrap` but has no entry in `associated_type_enum_names`.
fn generate_type_implementation(context: &MacroContext, type_: &TraitItemType) -> TokenStream {
    let type_ident = &type_.ident;
    let unification = context
        .associated_type_config
        .get(type_ident)
        .expect("All associated types should have a parsed type configuration.");

    match unification {
        AssociatedTypeUnification::EnumWrap => {
            let wrapper_enum = context
                .associated_type_enum_names
                .get(type_ident)
                .expect("All EnumWrap types should have an entry in associated_type_enum_names");
            quote! {
                type #type_ident = #wrapper_enum;
            }
        }
        AssociatedTypeUnification::Same => redeclare_first_variant_associated_type(context, type_),
    }
}
fn redeclare_first_variant_associated_type(
context: &MacroContext,
type_: &TraitItemType,
) -> TokenStream {
let type_ident = &type_.ident;
let trait_path = context.trait_path;
let variant_type = &context.parsed_enum.first_variant().type_;
quote!(
type #type_ident = <#variant_type as #trait_path>::#type_ident;
)
}
fn generate_fn_implementation(
context: &MacroContext,
method: &TraitItemMethod,
) -> Result<TokenStream, InvalidInput> {
let MacroContext {
trait_path,
enum_path,
..
} = *context;
let Signature {
ident: method_name,
inputs,
..
} = &method.sig;
let arguments_to_pass: Vec<_> = inputs
.iter()
.skip(1)
.map(|input| match input {
FnArg::Receiver(_) => {
unreachable!("only first argument, which we skipped, can be receiver")
}
FnArg::Typed(t) => {
let argument = &t.pat;
quote!(#argument.try_into().expect("Received argument of incorrect type"))
}
})
.collect();
let has_self_param = matches!(inputs.first(), Some(FnArg::Receiver(_)));
if !has_self_param {
return Err(InvalidInput::NoReceiverArgument(method.sig.span()));
}
let match_arms: Result<Vec<_>, _> = context
.parsed_enum
.variants()
.iter()
.map(|variant| {
let variant_name = &variant.name;
let target = format_ident!("target");
Ok(quote! {
#enum_path::#variant_name(#target) => {
#trait_path::#method_name( #target, #(#arguments_to_pass,)* ).into()
}
})
})
.collect();
let match_arms = match_arms?;
let signature = &method.sig;
Ok(quote! {
#signature {
match self {
#(#match_arms)*
}
}
})
}