moxy-derive 0.0.4

derive macros for moxy crate
Documentation
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

use crate::{
    Render,
    core::{Attrs, Field, pascal_to_snake},
    params,
};

use super::syntax::VariantSyntax;

/// Renderer that generates variant predicates and accessors for an enum.
///
/// Holds no state of its own: every input arrives through the arguments
/// passed to [`Render::render`].
#[derive(Default, Clone)]
pub struct EnumSyntax;

impl Render for EnumSyntax {
    type Args = params::EnumParams;

    fn render(&self, args: Self::Args) -> syn::Result<TokenStream> {
        let ident = &args.input.ident;
        let (impl_generics, type_generics, where_generics) = args.input.generics.split_for_impl();
        let mut variant_methods = Vec::new();
        let mut named_variants: Vec<(&syn::Ident, Vec<Field>)> = Vec::new();
        let mut skipped_variants: Vec<&syn::Ident> = Vec::new();

        for variant in &args.data.variants {
            let variant_ident = &variant.ident;
            let attrs = Attrs::parse(&variant.attrs)?;
            let opts = VariantSyntax::parse(&attrs)?;

            if opts.skip {
                skipped_variants.push(variant_ident);
                continue;
            }

            let base_name = opts
                .alias
                .unwrap_or_else(|| pascal_to_snake(&variant_ident.to_string()));

            let is_name = format_ident!("is_{}", base_name);
            let as_name = format_ident!("as_{}", base_name);
            let to_name = format_ident!("to_{}", base_name);

            let fields: Vec<_> = variant
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Field::parse(i, f))
                .collect::<syn::Result<Vec<_>>>()?;

            match &variant.fields {
                syn::Fields::Unit => {
                    variant_methods.push(quote! {
                        pub fn #is_name(&self) -> bool {
                            matches!(self, Self::#variant_ident)
                        }
                    });
                }
                syn::Fields::Unnamed(_) => {
                    let bindings: Vec<_> = fields
                        .iter()
                        .enumerate()
                        .map(|(i, _)| format_ident!("__v{}", i))
                        .collect();

                    let ref_types: Vec<_> = fields.iter().map(|f| f.ty()).collect();

                    let is_pattern = if bindings.is_empty() {
                        quote! { Self::#variant_ident() }
                    } else {
                        quote! { Self::#variant_ident(..) }
                    };

                    variant_methods.push(quote! {
                        pub fn #is_name(&self) -> bool {
                            matches!(self, #is_pattern)
                        }
                    });

                    if !bindings.is_empty() {
                        let pattern = quote! { Self::#variant_ident(#(#bindings),*) };

                        if bindings.len() == 1 {
                            let b = &bindings[0];
                            let ty = ref_types[0];

                            variant_methods.push(quote! {
                                pub fn #as_name(&self) -> ::std::option::Option<&#ty> {
                                    match self {
                                        #pattern => ::std::option::Option::Some(#b),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });

                            variant_methods.push(quote! {
                                pub fn #to_name(&self) -> ::std::option::Option<#ty>
                                where #ty: ::std::clone::Clone
                                {
                                    match self {
                                        #pattern => ::std::option::Option::Some(#b.clone()),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });
                        } else {
                            let ref_tuple: Vec<_> = ref_types.iter().map(|t| quote!(&#t)).collect();
                            let ref_vals: Vec<_> = bindings.iter().map(|b| quote!(#b)).collect();
                            let clone_vals: Vec<_> =
                                bindings.iter().map(|b| quote!(#b.clone())).collect();

                            variant_methods.push(quote! {
                                pub fn #as_name(&self) -> ::std::option::Option<(#(#ref_tuple),*)> {
                                    match self {
                                        #pattern => ::std::option::Option::Some((#(#ref_vals),*)),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });

                            variant_methods.push(quote! {
                                pub fn #to_name(&self) -> ::std::option::Option<(#(#ref_types),*)>
                                where #(#ref_types: ::std::clone::Clone),*
                                {
                                    match self {
                                        #pattern => ::std::option::Option::Some((#(#clone_vals),*)),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });
                        }
                    }
                }
                syn::Fields::Named(_) => {
                    named_variants.push((variant_ident, fields.clone()));

                    let field_idents: Vec<_> = fields
                        .iter()
                        .filter_map(|f| match f.name() {
                            crate::core::FieldName::Ident(id) => Some(id.clone()),
                            _ => None,
                        })
                        .collect();

                    let ref_types: Vec<_> = fields.iter().map(|f| f.ty()).collect();

                    variant_methods.push(quote! {
                        pub fn #is_name(&self) -> bool {
                            matches!(self, Self::#variant_ident { .. })
                        }
                    });

                    if !field_idents.is_empty() {
                        let pattern = quote! { Self::#variant_ident { #(#field_idents),*, .. } };

                        if field_idents.len() == 1 {
                            let f = &field_idents[0];
                            let ty = ref_types[0];

                            variant_methods.push(quote! {
                                pub fn #as_name(&self) -> ::std::option::Option<&#ty> {
                                    match self {
                                        #pattern => ::std::option::Option::Some(#f),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });

                            variant_methods.push(quote! {
                                pub fn #to_name(&self) -> ::std::option::Option<#ty>
                                where #ty: ::std::clone::Clone
                                {
                                    match self {
                                        #pattern => ::std::option::Option::Some(#f.clone()),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });
                        } else {
                            let ref_tuple: Vec<_> = ref_types.iter().map(|t| quote!(&#t)).collect();
                            let ref_vals: Vec<_> =
                                field_idents.iter().map(|f| quote!(#f)).collect();
                            let clone_vals: Vec<_> =
                                field_idents.iter().map(|f| quote!(#f.clone())).collect();

                            variant_methods.push(quote! {
                                pub fn #as_name(&self) -> ::std::option::Option<(#(#ref_tuple),*)> {
                                    match self {
                                        #pattern => ::std::option::Option::Some((#(#ref_vals),*)),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });

                            variant_methods.push(quote! {
                                pub fn #to_name(&self) -> ::std::option::Option<(#(#ref_types),*)>
                                where #(#ref_types: ::std::clone::Clone),*
                                {
                                    match self {
                                        #pattern => ::std::option::Option::Some((#(#clone_vals),*)),
                                        _ => ::std::option::Option::None,
                                    }
                                }
                            });
                        }
                    }
                }
            }
        }

        let common_methods = render_common_fields(ident, &named_variants, &skipped_variants);

        Ok(quote! {
            impl #impl_generics #ident #type_generics #where_generics {
                #(#variant_methods)*
                #common_methods
            }
        })
    }
}

/// Renders `&`-returning getters for named fields shared by struct-like variants.
///
/// A field of the first entry in `named_variants` gets a getter
/// `pub fn <field>(&self) -> &<Type>` when every other entry declares a field
/// with the same name and a token-identical type. Nothing is emitted when
/// fewer than two struct-like variants exist.
///
/// The generated getter panics when called on a variant that is not in
/// `named_variants`, such as a unit, tuple, or skipped variant. The panic
/// message names `enum_ident` and the field.
///
/// `_skipped` is currently unused.
fn render_common_fields(
    enum_ident: &syn::Ident,
    named_variants: &[(&syn::Ident, Vec<Field>)],
    _skipped: &[&syn::Ident],
) -> TokenStream {
    if named_variants.len() < 2 {
        return quote!();
    }

    let first_fields = &named_variants[0].1;
    let mut common: Vec<(&syn::Ident, &syn::Type)> = Vec::new();

    for field in first_fields {
        let crate::core::FieldName::Ident(name) = field.name() else {
            continue;
        };

        let ty = field.ty();
        let ty_str = quote!(#ty).to_string();

        // Keep the field only if every other struct-like variant has a field
        // with the same name and the same type, comparing types by token text.
        let in_all = named_variants[1..].iter().all(|(_, fields)| {
            fields.iter().any(|f| {
                let crate::core::FieldName::Ident(n) = f.name() else {
                    return false;
                };
                let fty = f.ty();
                n == name && quote!(#fty).to_string() == ty_str
            })
        });

        if in_all {
            common.push((name, ty));
        }
    }

    let mut methods = Vec::with_capacity(common.len());

    for (field_name, field_ty) in &common {
        let arms: Vec<_> = named_variants
            .iter()
            .map(|(variant_ident, _)| {
                quote! { Self::#variant_ident { #field_name, .. } => #field_name }
            })
            .collect();

        // Unit, tuple, and skipped variants fall through to this arm, so it
        // is reachable and must explain itself rather than claim to be
        // `unreachable!()`.
        let message = format!(
            "`{}::{}()` called on a variant without a `{}` field",
            enum_ident, field_name, field_name
        );

        methods.push(quote! {
            pub fn #field_name(&self) -> &#field_ty {
                match self {
                    #(#arms,)*
                    #[allow(unreachable_patterns)]
                    _ => ::std::panic!("{}", #message),
                }
            }
        });
    }

    quote!(#(#methods)*)
}