use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Fields, Ident, ItemStruct};
/// Expands the `type_state` attribute on a struct with named fields.
///
/// The struct is re-emitted with `state_slots` extra generic type parameters
/// (`State1`, `State2`, ...). Each defaults to `default_state` and is bounded by
/// `TypeStateProtector`. A zero-sized `_state` field is appended that mentions every
/// state parameter through `PhantomData<fn() -> StateN>`.
///
/// `args` is read positionally. Token index 2 must be an integer literal (the slot
/// count) and token index 6 must be an identifier (the default state). This matches an
/// argument list shaped like `state_slots = 3, default_state = Initial`.
///
/// # Panics
///
/// Panics if the arguments do not have that shape, or if the annotated struct is a
/// tuple struct or a unit struct.
pub fn type_state_inner(args: TokenStream, input: TokenStream) -> TokenStream {
    let input_args: Vec<_> = args.into_iter().collect();

    let state_slots: usize = match input_args.get(2) {
        // A literal that is not a plain base-10 `usize` (e.g. `3usize`, `-1`) gets the same
        // diagnostic as a missing literal, instead of an opaque `ParseIntError` unwrap panic.
        Some(proc_macro::TokenTree::Literal(lit)) => lit
            .to_string()
            .parse()
            .unwrap_or_else(|_| panic!("Expected a valid number for state_slots.")),
        _ => panic!("Expected a valid number for state_slots."),
    };
    let default_state: Ident = match input_args.get(6) {
        Some(proc_macro::TokenTree::Ident(ident)) => {
            Ident::new(&ident.to_string(), ident.span().into())
        }
        _ => panic!("Expected an identifier for default_state."),
    };

    let input_struct = parse_macro_input!(input as ItemStruct);
    // User attributes (derives, doc comments, ...) are forwarded; previously they were lost.
    let attrs = &input_struct.attrs;
    let struct_name = &input_struct.ident;
    let generics = &input_struct.generics;
    let visibility = &input_struct.vis;
    let struct_fields = match &input_struct.fields {
        Fields::Named(fields) => &fields.named,
        Fields::Unnamed(_) => panic!("Expected named fields in struct."),
        Fields::Unit => panic!("Expected a struct with fields."),
    };

    // `State1` ..= `State{state_slots}`, built once and reused for the params, bounds and field.
    let state_idents: Vec<Ident> = (1..=state_slots)
        .map(|i| Ident::new(&format!("State{}", i), struct_name.span()))
        .collect();

    // The user's generics come first, then the defaulted state parameters. Defaulted type
    // parameters must follow non-defaulted ones, and lifetimes must lead. Joining one list
    // with `,` avoids the dangling-comma cases of gluing two lists together.
    let all_params = generics
        .params
        .iter()
        .map(|param| quote!(#param))
        .chain(state_idents.iter().map(|ident| quote!(#ident = #default_state)));

    // Re-join the user's predicates with ours. Emitting the user's `WhereClause` verbatim
    // and appending more predicates lacks a separating comma unless the user wrote a
    // trailing one.
    // NOTE(review): `TypeStateProtector` is emitted unqualified with a call-site span, so it
    // presumably must be in scope where the attribute is used — confirm intended.
    let predicates: Vec<proc_macro2::TokenStream> = generics
        .where_clause
        .iter()
        .flat_map(|clause| clause.predicates.iter())
        .map(|predicate| quote!(#predicate))
        .chain(
            state_idents
                .iter()
                .map(|ident| quote!(#ident: TypeStateProtector)),
        )
        .collect();
    let where_clause = if predicates.is_empty() {
        quote!()
    } else {
        quote!(where #(#predicates),*)
    };

    // Each field is followed by its own comma. `Punctuated`'s ToTokens only emits a trailing
    // comma if the user wrote one, and `_state` is appended after the last field.
    let fields = struct_fields.iter();

    // `fn() -> StateN` mentions the parameter without storing or owning a `StateN` value.
    // NOTE: with exactly one slot `(#(..),*)` is a parenthesized type, not a 1-tuple. This is
    // kept as-is because code elsewhere may construct `_state` with the matching shape.
    let phantom_fields = state_idents
        .iter()
        .map(|ident| quote!(::std::marker::PhantomData<fn() -> #ident>));

    let output = quote! {
        #(#attrs)*
        #[allow(clippy::type_complexity)]
        #visibility struct #struct_name<#(#all_params),*>
        #where_clause
        {
            #(#fields,)*
            _state: (#(#phantom_fields),*),
        }
    };
    output.into()
}