typeshift-derive 0.5.1

Proc-macro derive support for typeshift
Documentation
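//! Proc macros for `typeshift`.
//!
//! `#[typeshift]` is the primary entry point. It augments a struct or enum with
//! the derives and `crate = "..."` helper attributes required by `serde`,
//! `validator`, and `schemars`, pointing each one at its re-export under
//! `typeshift`. Enums do not receive a `Validate` derive; instead the macro
//! generates a `Validate` impl that validates the fields of the active variant.
//!
//! `#[derive(TypeShift)]` is a legacy compatibility derive that expands to nothing.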

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Attribute, Fields, Item, ItemEnum, parse_macro_input};

#[proc_macro_attribute]
pub fn typeshift(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut item = parse_macro_input!(item as Item);

    match &mut item {
        Item::Struct(input) => {
            apply_typeshift_attrs(&mut input.attrs, true);
            quote!(#input).into()
        }
        Item::Enum(input) => {
            apply_typeshift_attrs(&mut input.attrs, false);

            let validate_impl = build_enum_validate_impl(input);

            quote! {
                #input
                #validate_impl
            }
            .into()
        }
        _ => syn::Error::new_spanned(item, "#[typeshift] supports structs and enums only")
            .to_compile_error()
            .into(),
    }
}

fn build_enum_validate_impl(input: &ItemEnum) -> proc_macro2::TokenStream {
    if has_derived_trait(&input.attrs, "Validate") {
        return quote! {};
    }

    let ident = &input.ident;
    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let helper_generics_def = helper_def_generics(generics);
    let helper_generics_use = helper_use_generics(generics);

    let helper_defs = input.variants.iter().filter_map(|variant| {
        let variant_ident = &variant.ident;
        let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
        match &variant.fields {
            Fields::Unit => None,
            Fields::Named(fields) => {
                let defs = fields.named.iter().map(|field| {
                    let attrs = validate_attrs(&field.attrs);
                    let name = match &field.ident {
                        Some(name) => name,
                        None => unreachable!("named field must have ident"),
                    };
                    let ty = &field.ty;
                    quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
                });

                Some(quote! {
                    #[allow(dead_code)]
                    #[derive(::typeshift::validator::Validate)]
                    #[validate(crate = "typeshift::validator")]
                    struct #helper_ident #helper_generics_def #where_clause {
                        #(#defs,)*
                    }
                })
            }
            Fields::Unnamed(fields) => {
                let defs = fields.unnamed.iter().enumerate().map(|(idx, field)| {
                    let attrs = validate_attrs(&field.attrs);
                    let name = format_ident!("__field_{idx}");
                    let ty = &field.ty;
                    quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
                });

                Some(quote! {
                    #[allow(dead_code)]
                    #[derive(::typeshift::validator::Validate)]
                    #[validate(crate = "typeshift::validator")]
                    struct #helper_ident #helper_generics_def #where_clause {
                        #(#defs,)*
                    }
                })
            }
        }
    });

    let arms = input.variants.iter().map(|variant| {
        let variant_ident = &variant.ident;
        let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
        match &variant.fields {
            Fields::Unit => {
                quote! {
                    Self::#variant_ident => ::core::result::Result::Ok(())
                }
            }
            Fields::Named(fields) => {
                let names: Vec<_> = fields
                    .named
                    .iter()
                    .filter_map(|field| field.ident.as_ref())
                    .collect();
                quote! {
                    Self::#variant_ident { #(#names,)* } => {
                        let helper = #helper_ident #helper_generics_use { #(#names,)* };
                        ::typeshift::validator::Validate::validate(&helper)
                    }
                }
            }
            Fields::Unnamed(fields) => {
                let bindings: Vec<_> = fields
                    .unnamed
                    .iter()
                    .enumerate()
                    .map(|(idx, _)| format_ident!("__field_{idx}"))
                    .collect();
                let init_fields = bindings.iter().map(|name| quote! { #name: #name });
                quote! {
                    Self::#variant_ident( #(#bindings,)* ) => {
                        let helper = #helper_ident #helper_generics_use { #(#init_fields,)* };
                        ::typeshift::validator::Validate::validate(&helper)
                    }
                }
            }
        }
    });

    quote! {
        #(#helper_defs)*

        impl #impl_generics ::typeshift::validator::Validate for #ident #ty_generics #where_clause {
            fn validate(&self) -> ::core::result::Result<(), ::typeshift::validator::ValidationErrors> {
                match self {
                    #(#arms,)*
                }
            }
        }
    }
}

/// Generic parameter list for the definition of a variant helper struct.
///
/// This is the `'__typeshift_enum_validate` borrow lifetime followed by the
/// enum's own parameters, copied verbatim (bounds and defaults included).
fn helper_def_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
    if !generics.params.is_empty() {
        let enum_params = &generics.params;
        return quote! { <'__typeshift_enum_validate, #enum_params> };
    }
    quote! { <'__typeshift_enum_validate> }
}

/// Turbofish generic arguments used when constructing a variant helper struct.
///
/// The helper lifetime is left to inference (`'_`). It is followed by each of
/// the enum's parameters by name only (lifetime, type or const identifier,
/// without bounds or defaults), each with a trailing comma.
fn helper_use_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
    let mut names: Vec<proc_macro2::TokenStream> = Vec::new();
    for param in &generics.params {
        names.push(match param {
            syn::GenericParam::Lifetime(lifetime_param) => {
                let lt = &lifetime_param.lifetime;
                quote! { #lt }
            }
            syn::GenericParam::Type(type_param) => {
                let name = &type_param.ident;
                quote! { #name }
            }
            syn::GenericParam::Const(const_param) => {
                let name = &const_param.ident;
                quote! { #name }
            }
        });
    }

    match names.as_slice() {
        [] => quote! { ::<'_> },
        _ => quote! { ::<'_, #(#names,)*> },
    }
}

/// Returns clones of the `#[validate(...)]` attributes in `attrs`, in their
/// original order. All other attributes are dropped.
fn validate_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
    let mut kept = Vec::new();
    for attr in attrs {
        if attr.path().is_ident("validate") {
            kept.push(attr.clone());
        }
    }
    kept
}

/// Legacy compatibility derive.
///
/// Expands to no code at all. Its only effect is to register `validate`,
/// `serde` and `schemars` as helper attributes on the item it is applied to.
/// Prefer `#[typeshift]` as the primary macro entry point.
#[proc_macro_derive(TypeShift, attributes(validate, serde, schemars))]
pub fn derive_typeshift(_input: TokenStream) -> TokenStream {
    TokenStream::new()
}

/// Adds the derives and `crate = "..."` container attributes that
/// `#[typeshift]` needs, skipping any the item already has.
///
/// When `include_validate` is `true` (structs), `Validate` is added to the
/// derive set and a `#[validate(crate = ...)]` attribute is ensured as well.
/// When it is `false` (enums), only the serde/schemars pieces are applied.
fn apply_typeshift_attrs(attrs: &mut Vec<Attribute>, include_validate: bool) {
    let derive_names: &[&str] = if include_validate {
        &["Serialize", "Deserialize", "JsonSchema", "Validate"]
    } else {
        &["Serialize", "Deserialize", "JsonSchema"]
    };
    add_missing_derives(attrs, derive_names);

    let crate_overrides = [
        ("serde", "crate = \"typeshift::serde\""),
        ("schemars", "crate = \"typeshift::schemars\""),
    ];
    for (attr_name, attr_args) in crate_overrides {
        ensure_attr(attrs, attr_name, attr_args);
    }

    if include_validate {
        ensure_attr(attrs, "validate", "crate = \"typeshift::validator\"");
    }
}

/// Reports whether any `#[derive(...)]` in `attrs` lists `trait_name`.
///
/// Only the final path segment is compared, so `Validate`,
/// `validator::Validate` and `::typeshift::validator::Validate` all match.
/// A `derive` attribute whose arguments do not parse as a comma-separated list
/// of paths is ignored.
fn has_derived_trait(attrs: &[Attribute], trait_name: &str) -> bool {
    for attr in attrs {
        if !attr.path().is_ident("derive") {
            continue;
        }
        let parsed = attr.parse_args_with(
            syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
        );
        if let Ok(paths) = parsed {
            let found = paths.iter().any(|path| {
                matches!(path.segments.last(), Some(segment) if segment.ident == trait_name)
            });
            if found {
                return true;
            }
        }
    }
    false
}

/// Adds a single `#[derive(...)]` that lists every trait in `required` not
/// already derived on the item.
///
/// Each trait is spelled with its fully qualified `::typeshift::...` re-export
/// path. Names without a known re-export are skipped. Nothing is added when
/// every required trait is already present.
fn add_missing_derives(attrs: &mut Vec<Attribute>, required: &[&str]) {
    let missing: Vec<syn::Path> = {
        let existing: &[Attribute] = attrs;
        required
            .iter()
            .copied()
            .filter(|name| !has_derived_trait(existing, name))
            .filter_map(|name| {
                let path: syn::Path = match name {
                    "Serialize" => syn::parse_quote!(::typeshift::serde::Serialize),
                    "Deserialize" => syn::parse_quote!(::typeshift::serde::Deserialize),
                    "Validate" => syn::parse_quote!(::typeshift::validator::Validate),
                    "JsonSchema" => syn::parse_quote!(::typeshift::schemars::JsonSchema),
                    _ => return None,
                };
                Some(path)
            })
            .collect()
    };

    if missing.is_empty() {
        return;
    }

    // Insert directly after the last existing `#[derive]`, or at the front when
    // the item has none.
    let insert_at = match attrs.iter().rposition(|attr| attr.path().is_ident("derive")) {
        Some(last_derive) => last_derive + 1,
        None => 0,
    };
    attrs.insert(insert_at, syn::parse_quote!(#[derive(#(#missing),*)]));
}

fn ensure_attr(attrs: &mut Vec<Attribute>, name: &str, args: &str) {
    let path = syn::Ident::new(name, proc_macro2::Span::call_site());
    let args: proc_macro2::TokenStream = match args.parse() {
        Ok(args) => args,
        Err(_) => return,
    };

    let has_crate_arg = attrs
        .iter()
        .any(|attr| attr.path().is_ident(name) && attr_has_crate_arg(attr));

    if !has_crate_arg {
        attrs.push(syn::parse_quote!(#[#path(#args)]));
    }
}

/// Returns `true` when `attr`'s argument list contains a `crate = ...`
/// name-value entry.
///
/// Arguments that cannot be parsed as comma-separated `syn::Meta` items are
/// treated as not containing one.
fn attr_has_crate_arg(attr: &Attribute) -> bool {
    let parsed = attr.parse_args_with(
        syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
    );
    match parsed {
        Ok(metas) => metas.iter().any(|meta| {
            matches!(meta, syn::Meta::NameValue(name_value) if name_value.path.is_ident("crate"))
        }),
        Err(_) => false,
    }
}