use std::ops::Deref;
use proc_macro2::Span;
use syn::parse_quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token;
use syn::token::Comma;
use syn::Arm;
use syn::Field;
use syn::FnArg;
use syn::Ident;
use syn::Pat;
use syn::Signature;
use syn::TraitItemMethod;
use quote::ToTokens;
#[derive(Debug)]
pub struct VariantSig<'info> {
enum_ident: &'info Ident,
variant_ident: &'info Ident,
caller: Ident,
params: Composite,
}
#[derive(Debug)]
pub enum Position<'a> {
Index(usize, &'a Field),
Key(&'a Ident),
}
#[derive(Debug)]
pub enum Param {
Ident(Ident),
Placeholder,
Rest,
}
#[derive(Debug)]
pub enum Composite {
Named(Punctuated<Param, Comma>, token::Brace),
Unnamed(Punctuated<Param, Comma>, token::Paren),
}
impl<'a> Position<'a> {
pub fn from_field(field: &'a Field, fallback: usize) -> Self {
field
.ident
.as_ref()
.map(Position::Key)
.unwrap_or(Position::Index(fallback, field))
}
pub fn get_caller(&self) -> Ident {
match self {
Position::Index(_, _) => parse_quote! {val},
Position::Key(key) => parse_quote! {#key},
}
}
}
impl<'info> VariantSig<'info> {
pub fn new(
enum_ident: &'info Ident,
variant_ident: &'info Ident,
field: &Field,
field_index: usize,
max_length: usize,
) -> Self {
let position = Position::from_field(field, field_index);
let caller = position.get_caller();
let fields = position.format_fields_pattern(max_length);
Self {
enum_ident,
variant_ident,
caller,
params: fields,
}
}
pub fn parse_arm(&'info self, method: &'info TraitItemMethod) -> (&Ident, Arm) {
let Self {
enum_ident,
variant_ident,
caller,
params: fields,
..
} = self;
let (method_ident, sanitized_input) = get_method_parts(method);
(
method_ident,
parse_quote! {#enum_ident :: #variant_ident #fields => #caller . #method_ident (#sanitized_input)},
)
}
}
impl<'a> Position<'a> {
pub fn format_fields_pattern(&self, arity: usize) -> Composite {
let mut punc = Punctuated::<Param, Comma>::new();
match self {
Position::Index(index, field) => {
for _ in 1..*index + 1 {
punc.push_value(Param::Placeholder);
punc.push_punct(Comma(field.span()));
}
punc.push_value(Param::Ident(Ident::new("val", field.span())));
if arity > index + 1 {
punc.push_punct(Comma(field.span()));
punc.push_value(Param::Rest);
}
Composite::Unnamed(punc, token::Paren(field.span()))
}
Position::Key(key) => {
punc.push_value(Param::Ident((*key).clone()));
if arity > 1 {
punc.push_punct(Comma(key.span()));
punc.push_value(Param::Rest);
}
Composite::Named(punc, token::Brace(key.span()))
}
}
}
}
impl ToTokens for Composite {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
match self {
Composite::Named(param, brace) => {
brace.surround(tokens, |tokens| param.to_tokens(tokens))
}
Composite::Unnamed(param, paren) => {
paren.surround(tokens, |tokens| param.to_tokens(tokens))
}
}
}
}
impl ToTokens for Param {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
match self {
Param::Ident(ident) => ident.to_tokens(tokens),
Param::Placeholder => token::Underscore(Span::call_site()).to_tokens(tokens),
Param::Rest => token::Dot2(Span::call_site()).to_tokens(tokens),
}
}
}
fn sanitize(inputs: &Punctuated<FnArg, Comma>) -> Punctuated<Pat, Comma> {
let mut san = Punctuated::new();
let max = inputs.len();
inputs.iter().enumerate().for_each(|(i, arg)| match arg {
syn::FnArg::Receiver(_) => (),
syn::FnArg::Typed(typed) => {
san.push_value(typed.pat.deref().clone());
if i != max - 1 {
san.push_punct(Comma(Span::call_site()));
}
}
});
san
}
fn get_method_parts(method: &TraitItemMethod) -> (&Ident, Punctuated<Pat, Comma>) {
let TraitItemMethod { sig, .. } = method;
let Signature { ident, inputs, .. } = sig;
(ident, sanitize(inputs))
}