udled-macros 0.1.0

Macros for audiens
Documentation
use darling::ast::NestedMeta;
use darling::{Error, FromMeta};
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::format_ident;
use quote::quote;
use syn::parse_macro_input;
use syn::Fields;
use syn::Ident;
use syn::ItemEnum;
use syn::LifetimeParam;

pub fn visitor(attr: TokenStream, item: TokenStream) -> TokenStream {
    let enum_item = parse_macro_input!(item as ItemEnum);

    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
        Ok(v) => v,
        Err(e) => {
            return TokenStream::from(Error::from(e).write_errors());
        }
    };

    let args = match OptionParser::from_list(&attr_args) {
        Ok(v) => v,
        Err(e) => {
            return TokenStream::from(e.write_errors());
        }
    };

    let mut output = vec![generate(&enum_item, VisitType::Ref)];

    if args.with_mut {
        output.push(generate(&enum_item, VisitType::Mut));
    }

    if args.with_fold {
        output.push(generate(&enum_item, VisitType::Fold));
    }

    quote!(
        #enum_item

        #(
            #output
        )*
    )
    .into()
}

/// Returns `true` when at least one of `fields` has no name.
///
/// Unit variants and variants with zero fields yield `false`, because there is
/// no field to inspect.
fn fields_is_tuple(fields: &Fields) -> bool {
    fields
        .iter()
        .map(|field| &field.ident)
        .any(Option::is_none)
}

/// Options accepted in the visitor attribute's argument list, parsed with
/// darling's `FromMeta`.
///
/// Both flags are optional; `#[darling(default)]` makes an omitted flag fall
/// back to `bool::default()`, i.e. `false`.
#[derive(Debug, FromMeta)]
struct OptionParser {
    /// When set, `visitor` additionally emits the `VisitType::Mut` flavour.
    #[darling(default)]
    with_mut: bool,
    /// When set, `visitor` additionally emits the `VisitType::Fold` flavour.
    #[darling(default)]
    with_fold: bool,
}

/// Which flavour of visitor code to generate.
enum VisitType {
    /// Fields are handed out as `&'ast` borrows.
    Ref,
    /// Fields are handed out as `&'ast mut` borrows.
    Mut,
    /// The enum is consumed and fields are handed out by value.
    Fold,
}

impl VisitType {
    /// Name of the per-variant trait method: `<prefix>_<variant>_<enum_name>`,
    /// where the prefix is `visit`, `visit_mut` or `fold`.
    fn method_name(
        &self,
        variant: impl std::fmt::Display,
        enum_name: impl std::fmt::Display,
    ) -> Ident {
        let prefix = match self {
            Self::Ref => "visit",
            Self::Mut => "visit_mut",
            Self::Fold => "fold",
        };
        format_ident!("{prefix}_{variant}_{enum_name}")
    }

    /// Name of the generated trait: `name` followed by `Visitor`,
    /// `VisitorMut` or `Fold`.
    fn trait_name(&self, name: impl std::fmt::Display) -> Ident {
        let suffix = match self {
            Self::Ref => "Visitor",
            Self::Mut => "VisitorMut",
            Self::Fold => "Fold",
        };
        format_ident!("{name}{suffix}")
    }

    /// Name of the inherent dispatch method generated on the enum.
    fn accept_name(&self) -> Ident {
        let name = match self {
            Self::Ref => "accept",
            Self::Mut => "accept_mut",
            Self::Fold => "fold",
        };
        format_ident!("{name}")
    }

    /// Reference tokens placed in front of `self` and of field types, or
    /// `None` for by-value folding.
    fn reference(&self) -> Option<TokenStream2> {
        let tokens = match self {
            Self::Ref => quote!(&'ast),
            Self::Mut => quote!(&'ast mut),
            Self::Fold => return None,
        };
        Some(tokens)
    }
}

fn generate(enum_item: &ItemEnum, kind: VisitType) -> TokenStream2 {
    let visitor_name = kind.trait_name(&enum_item.ident);

    let reference = kind.reference();

    let enum_name = format_ident!("{}", enum_item.ident.to_string().to_lowercase());

    let methods = enum_item.variants.iter().map(|variant| {
        let method_name = kind.method_name(variant.ident.to_string().to_lowercase(), &enum_name);

        let is_tuple = fields_is_tuple(&variant.fields);

        let fields = variant.fields.iter().map(|field| {
            let ty = &field.ty;

            if let Some(name) = &field.ident {
                quote!(
                    #name: #reference #ty
                )
            } else {
                quote!(
                    #ty
                )
            }
        });

        let fields = if is_tuple {
            if variant.fields.len() == 1 {
                quote!(
                    member: #reference #(#fields),*
                )
            } else {
                quote!(
                    member: (#(#reference #fields),*)
                )
            }
        } else {
            quote!(#(#fields),*)
        };

        quote!(
            fn #method_name(&mut self, #fields) -> Self::Output;
        )
    });

    let name = &enum_item.ident;

    let accept = enum_item.variants.iter().map(|variant| {
        let name = &variant.ident;

        let tuple = fields_is_tuple(&variant.fields);

        let fields = variant
            .fields
            .iter()
            .enumerate()
            .map(|(idx, field)| {
                if let Some(name) = &field.ident {
                    quote!(
                        #name
                    )
                } else {
                    let name = format_ident!("field_{}", idx);

                    quote!(
                        #name
                    )
                }
            })
            .collect::<Vec<_>>();

        let method_name = kind.method_name(variant.ident.to_string().to_lowercase(), &enum_name);

        if tuple {
            if fields.len() == 1 {
                quote!(
                    Self::#name(#(#fields),*) => visitor.#method_name(#(#fields),*)
                )
            } else {
                quote!(
                    Self::#name(#(#fields),*) => visitor.#method_name((#(#fields),*))
                )
            }
        } else {
            quote!(
                Self::#name { #(#fields),* } => visitor.#method_name(#(#fields),*)
            )
        }
    });

    let accept_method = kind.accept_name();

    let mut generics = enum_item.generics.clone();

    let enum_method = if reference.is_some() {
        generics
            .params
            .push(syn::GenericParam::Lifetime(LifetimeParam::new(
                syn::Lifetime::new("'ast", Span::call_site()),
            )));

        let (_generics_impl, generics_type, where_clause) = &generics.split_for_impl();

        quote!(
            pub fn #accept_method<'ast, V: #visitor_name #generics_type>(#reference self, visitor: &mut V) -> V::Output #where_clause {
                match self {
                    #(#accept),*
                }
            }
        )
    } else {
        let (_generics_impl, generics_type, where_clause) = &generics.split_for_impl();

        quote!(
                pub fn #accept_method<V: #visitor_name #generics_type>(#reference self, visitor: &mut V) -> V::Output #where_clause {
                match self {
                    #(#accept),*
                }
            }
        )
    };

    let (_generics_impl, generics_type, where_clause) = &generics.split_for_impl();

    let (enum_generics_impl, enum_generics_type, enum_where_clause) =
        &enum_item.generics.split_for_impl();

    quote!(

        pub trait #visitor_name #generics_type #where_clause {
            type Output;

            #(#methods)*
        }

        impl #enum_generics_impl  #name #enum_generics_type #enum_where_clause {
            #enum_method
        }


    )
}