use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use std::collections::HashSet;
use syn::{Fields, Generics, Ident, Visibility};
use crate::enum_parser::{ParsedMethod, ParsedVariant};
use crate::helpers::{
add_static_bounds, merge_generics, strip_pattern_generics, substitute_type_params,
};
use crate::type_analysis::{
collect_all_type_param_names, collect_variant_type_params, extract_trait_type_from_attrs,
};
/// Collects every identifier in `trait_type` that names one of
/// `all_type_params`, searching inside delimited groups (`<...>`, `(...)`,
/// etc.) as well as at the top level.
///
/// Returns the set of matching parameter names; identifiers that are not in
/// `all_type_params` are ignored.
fn extract_type_params_from_trait(
    trait_type: &TokenStream2,
    all_type_params: &HashSet<String>,
) -> HashSet<String> {
    use proc_macro2::TokenTree;

    let mut found = HashSet::new();
    // Streams still to be scanned; nested groups are pushed here instead of
    // being handled by a recursive call.
    let mut pending = vec![trait_type.clone()];

    while let Some(stream) = pending.pop() {
        for tree in stream {
            match tree {
                TokenTree::Group(group) => pending.push(group.stream()),
                TokenTree::Ident(ident) => {
                    let name = ident.to_string();
                    if all_type_params.contains(&name) {
                        found.insert(name);
                    }
                }
                TokenTree::Punct(_) | TokenTree::Literal(_) => {}
            }
        }
    }

    found
}
/// Emits a struct definition named `variant_name` with the given generic
/// parameters, visibility and fields.
///
/// Brace-delimited (named) fields end the item on their own. Tuple and unit
/// structs need a trailing `;`, which is appended for those shapes only. Only
/// the generic parameter list of `variant_generics` is interpolated, so any
/// `where` clause it carries is not emitted here.
pub fn generate_variant_struct(
    variant_name: &Ident,
    variant_generics: &Generics,
    fields: &Fields,
    vis: &Visibility,
) -> TokenStream2 {
    let needs_semicolon = !matches!(fields, Fields::Named(_));
    let terminator = needs_semicolon.then(|| quote! { ; });

    // `Fields::Unit` renders to nothing, so a unit struct becomes `struct Name<..>;`.
    quote! {
        #vis struct #variant_name #variant_generics #fields #terminator
    }
}
/// Reports whether `pattern` (the rendered text of a match-arm pattern)
/// mentions `variant_name` as a whole identifier rather than as a substring.
///
/// This keeps a variant named `Circle` from claiming the arm written for
/// `CircleFilled`.
fn pattern_mentions_variant(pattern: &str, variant_name: &str) -> bool {
    // A raw identifier (`r#Name`) renders with its prefix; compare the bare name.
    let bare = variant_name.strip_prefix("r#").unwrap_or(variant_name);
    pattern
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|word| word == bare)
}

/// Builds the implementation of one trait method for a single variant struct.
///
/// The first arm of `method` whose pattern names `variant` (as a whole
/// identifier) supplies the body. The method signature has the enum's type
/// parameters rewritten by `substitute_type_params`. If the rewritten
/// signature fails to re-parse, the original `method.sig` is used unchanged.
/// The generated body matches `self` against the arm's pattern (after
/// `strip_pattern_generics`) and hits `unreachable!()` otherwise.
///
/// For a `self: Box<Self>` receiver, the box is first coerced to
/// `Box<dyn std::any::Any>` and downcast to `#variant_name #variant_ty_generics`
/// before the unboxed value is matched.
///
/// Returns `None` when no arm mentions the variant. Otherwise returns the
/// method tokens plus a flag that is `true` for a `self: Box<Self>` receiver.
pub fn generate_method_body(
    variant: &ParsedVariant,
    method: &ParsedMethod,
    variant_ty_generics: &TokenStream2,
    trait_type: &TokenStream2,
    all_type_params_ordered: &[String],
) -> Option<(TokenStream2, bool)> {
    let variant_name = &variant.ident;
    let variant_name_str = variant_name.to_string();

    // Whole-identifier match: a plain `contains` on the rendered pattern would
    // also accept arms for variants whose names merely start with this one.
    let arm = method
        .arms
        .iter()
        .find(|arm| pattern_mentions_variant(&arm.pattern.to_string(), &variant_name_str))?;
    let body = &arm.body;
    let cleaned_pattern = strip_pattern_generics(&arm.pattern);

    let sig_str = method.sig.to_string();
    let new_sig_str = substitute_type_params(&sig_str, trait_type, all_type_params_ordered);
    // Fall back to the untouched signature if substitution produced invalid tokens.
    let new_sig: TokenStream2 = new_sig_str.parse().unwrap_or_else(|_| method.sig.clone());

    // Compare with whitespace removed so every spacing of `self: Box<Self>`
    // (e.g. proc_macro2's `self : Box < Self >`) is recognised.
    let compact_sig: String = sig_str.chars().filter(|c| !c.is_whitespace()).collect();
    let is_boxed_self = compact_sig.contains("self:Box<Self>");

    let match_expr = if is_boxed_self {
        quote! {
            let __concrete_box = (self as Box<dyn std::any::Any>)
                .downcast::<#variant_name #variant_ty_generics>()
                .expect("Downcast failed");
            match *__concrete_box {
                #cleaned_pattern => #body,
                _ => unreachable!(),
            }
        }
    } else {
        quote! {
            match self {
                #cleaned_pattern => #body,
                _ => unreachable!(),
            }
        }
    };

    let method_impl = quote! {
        #new_sig {
            #match_expr
        }
    };
    Some((method_impl, is_boxed_self))
}
/// Emits `impl <generics> #trait_type for #variant_name #variant_ty_generics
/// #where_clause { ... }` for one variant struct.
///
/// Each method in `methods` that has an arm for this variant contributes its
/// body (via `generate_method_body`); methods without such an arm are
/// skipped. When none contribute, the impl body is simply `{}`.
pub fn generate_combined_trait_impl(
    variant: &ParsedVariant,
    methods: &[ParsedMethod],
    impl_generics: &Generics,
    variant_ty_generics: &TokenStream2,
    where_clause: &TokenStream2,
    trait_type: &TokenStream2,
    all_type_params_ordered: &[String],
) -> TokenStream2 {
    let self_ty = &variant.ident;
    let (generics_tokens, _, _) = impl_generics.split_for_impl();

    let bodies: Vec<TokenStream2> = methods
        .iter()
        .filter_map(|method| {
            let (tokens, _is_boxed_self) = generate_method_body(
                variant,
                method,
                variant_ty_generics,
                trait_type,
                all_type_params_ordered,
            )?;
            Some(tokens)
        })
        .collect();

    // An empty repetition expands to nothing, producing the same `{}` body as
    // an explicit empty impl.
    quote! {
        impl #generics_tokens #trait_type
        for #self_ty #variant_ty_generics #where_clause {
            #(#bodies)*
        }
    }
}
pub fn generate_variant_code(
variant: &ParsedVariant,
methods: &[ParsedMethod],
generics_with_static: &Generics,
all_type_params: &HashSet<String>,
all_type_params_ordered: &[String],
vis: &Visibility,
enum_name: &Ident,
) -> TokenStream2 {
let variant_name = &variant.ident;
let variant_generics_with_static = add_static_bounds(&variant.generics);
let mut combined_type_params = collect_all_type_param_names(&variant_generics_with_static);
combined_type_params.extend(all_type_params.iter().cloned());
let struct_type_params = collect_variant_type_params(&variant.fields, &combined_type_params);
let struct_generics = merge_generics(
&variant_generics_with_static,
generics_with_static,
&struct_type_params,
);
let (_struct_impl_generics, variant_ty_generics, _struct_where_clause) =
struct_generics.split_for_impl();
let struct_def = generate_variant_struct(variant_name, &struct_generics, &variant.fields, vis);
let trait_type = if let Some(ref tt) = variant.trait_type {
tt.clone()
} else if let Some(tt) = extract_trait_type_from_attrs(&variant.attrs) {
tt
} else {
let ty_generics = generics_with_static.split_for_impl().1;
quote! { #enum_name #ty_generics }
};
let trait_type_params = extract_type_params_from_trait(&trait_type, all_type_params);
let mut impl_type_params = struct_type_params.clone();
impl_type_params.extend(trait_type_params);
let impl_generics = merge_generics(
&variant_generics_with_static,
generics_with_static,
&impl_type_params,
);
let (_impl_generics_tokens, _, where_clause_impl) = impl_generics.split_for_impl();
let trait_impl = generate_combined_trait_impl(
variant,
methods,
&impl_generics,
&variant_ty_generics.to_token_stream(),
&where_clause_impl.to_token_stream(),
&trait_type,
all_type_params_ordered,
);
quote! {
#struct_def
#trait_impl
}
}