use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Attribute, Fields, Item, ItemEnum, parse_macro_input};
#[proc_macro_attribute]
pub fn typeshift(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut item = parse_macro_input!(item as Item);
match &mut item {
Item::Struct(input) => {
apply_typeshift_attrs(&mut input.attrs, true);
quote!(#input).into()
}
Item::Enum(input) => {
apply_typeshift_attrs(&mut input.attrs, false);
let validate_impl = build_enum_validate_impl(input);
quote! {
#input
#validate_impl
}
.into()
}
_ => syn::Error::new_spanned(item, "#[typeshift] supports structs and enums only")
.to_compile_error()
.into(),
}
}
fn build_enum_validate_impl(input: &ItemEnum) -> proc_macro2::TokenStream {
if has_derived_trait(&input.attrs, "Validate") {
return quote! {};
}
let ident = &input.ident;
let generics = &input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let helper_generics_def = helper_def_generics(generics);
let helper_generics_use = helper_use_generics(generics);
let helper_defs = input.variants.iter().filter_map(|variant| {
let variant_ident = &variant.ident;
let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
match &variant.fields {
Fields::Unit => None,
Fields::Named(fields) => {
let defs = fields.named.iter().map(|field| {
let attrs = validate_attrs(&field.attrs);
let name = match &field.ident {
Some(name) => name,
None => unreachable!("named field must have ident"),
};
let ty = &field.ty;
quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
});
Some(quote! {
#[allow(dead_code)]
#[derive(::typeshift::validator::Validate)]
#[validate(crate = "typeshift::validator")]
struct #helper_ident #helper_generics_def #where_clause {
#(#defs,)*
}
})
}
Fields::Unnamed(fields) => {
let defs = fields.unnamed.iter().enumerate().map(|(idx, field)| {
let attrs = validate_attrs(&field.attrs);
let name = format_ident!("__field_{idx}");
let ty = &field.ty;
quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
});
Some(quote! {
#[allow(dead_code)]
#[derive(::typeshift::validator::Validate)]
#[validate(crate = "typeshift::validator")]
struct #helper_ident #helper_generics_def #where_clause {
#(#defs,)*
}
})
}
}
});
let arms = input.variants.iter().map(|variant| {
let variant_ident = &variant.ident;
let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
match &variant.fields {
Fields::Unit => {
quote! {
Self::#variant_ident => ::core::result::Result::Ok(())
}
}
Fields::Named(fields) => {
let names: Vec<_> = fields
.named
.iter()
.filter_map(|field| field.ident.as_ref())
.collect();
quote! {
Self::#variant_ident { #(#names,)* } => {
let helper = #helper_ident #helper_generics_use { #(#names,)* };
::typeshift::validator::Validate::validate(&helper)
}
}
}
Fields::Unnamed(fields) => {
let bindings: Vec<_> = fields
.unnamed
.iter()
.enumerate()
.map(|(idx, _)| format_ident!("__field_{idx}"))
.collect();
let init_fields = bindings.iter().map(|name| quote! { #name: #name });
quote! {
Self::#variant_ident( #(#bindings,)* ) => {
let helper = #helper_ident #helper_generics_use { #(#init_fields,)* };
::typeshift::validator::Validate::validate(&helper)
}
}
}
}
});
quote! {
#(#helper_defs)*
impl #impl_generics ::typeshift::validator::Validate for #ident #ty_generics #where_clause {
fn validate(&self) -> ::core::result::Result<(), ::typeshift::validator::ValidationErrors> {
match self {
#(#arms,)*
}
}
}
}
}
/// Generic parameter list used when *declaring* a helper struct.
///
/// The list is the borrow lifetime `'__typeshift_enum_validate` followed by the
/// enum's own parameters, copied verbatim (bounds and defaults included).
fn helper_def_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
    let enum_params = generics.params.iter();
    quote! { <'__typeshift_enum_validate #(, #enum_params)*> }
}
/// Turbofish used when *constructing* a helper struct inside the `Validate` impl.
///
/// The helper lifetime is left to inference (`'_`). Each enum parameter is then
/// forwarded by name: the lifetime for lifetime params, the ident for type and
/// const params.
fn helper_use_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
    let enum_args = generics.params.iter().map(|param| match param {
        syn::GenericParam::Lifetime(def) => {
            let lifetime = &def.lifetime;
            quote!(#lifetime)
        }
        syn::GenericParam::Type(syn::TypeParam { ident, .. })
        | syn::GenericParam::Const(syn::ConstParam { ident, .. }) => quote!(#ident),
    });
    quote! { ::<'_ #(, #enum_args)*> }
}
/// Returns clones of the `#[validate(...)]` attributes in `attrs`, in their
/// original order.
fn validate_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
    let mut kept = Vec::new();
    for attr in attrs {
        if attr.path().is_ident("validate") {
            kept.push(attr.clone());
        }
    }
    kept
}
/// `TypeShift` derive that generates no code.
///
/// It declares `validate`, `serde` and `schemars` as helper attributes, so those
/// attributes are accepted on items that derive `TypeShift`. Its output is always
/// empty.
#[proc_macro_derive(TypeShift, attributes(validate, serde, schemars))]
pub fn derive_typeshift(_input: TokenStream) -> TokenStream {
    TokenStream::default()
}
/// Adds the derives and crate-path attributes that `#[typeshift]` requires.
///
/// * Always ensures `Serialize`, `Deserialize` and `JsonSchema` are derived.
/// * Always ensures `#[serde(crate = ...)]` and `#[schemars(crate = ...)]`
///   attributes point at the `typeshift` re-exports.
/// * When `include_validate` is true, also ensures `Validate` is derived and
///   adds `#[validate(crate = ...)]`.
fn apply_typeshift_attrs(attrs: &mut Vec<Attribute>, include_validate: bool) {
    let validate_derive: &[&str] = if include_validate { &["Validate"] } else { &[] };
    let required: Vec<&str> = ["Serialize", "Deserialize", "JsonSchema"]
        .iter()
        .chain(validate_derive)
        .copied()
        .collect();
    add_missing_derives(attrs, &required);

    let crate_overrides = [
        ("serde", "crate = \"typeshift::serde\""),
        ("schemars", "crate = \"typeshift::schemars\""),
    ];
    for (attr_name, attr_args) in crate_overrides {
        ensure_attr(attrs, attr_name, attr_args);
    }
    if include_validate {
        ensure_attr(attrs, "validate", "crate = \"typeshift::validator\"");
    }
}
/// Reports whether any `#[derive(...)]` in `attrs` names `trait_name`.
///
/// Only the last path segment is compared, so `Serialize`, `serde::Serialize` and
/// `::typeshift::serde::Serialize` all match `"Serialize"`. Derive lists that fail
/// to parse are skipped.
fn has_derived_trait(attrs: &[Attribute], trait_name: &str) -> bool {
    type DeriveList = syn::punctuated::Punctuated<syn::Path, syn::Token![,]>;
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("derive"))
        .filter_map(|attr| attr.parse_args_with(DeriveList::parse_terminated).ok())
        .any(|derived| {
            derived.iter().any(|path| {
                path.segments
                    .last()
                    .map_or(false, |segment| segment.ident == trait_name)
            })
        })
}
/// Adds one `#[derive(...)]` containing every trait in `required` that is not
/// already derived.
///
/// Each trait is referenced through its `::typeshift::...` re-export path. Names
/// other than `Serialize`, `Deserialize`, `Validate` and `JsonSchema` are ignored.
/// The new attribute goes right after the last existing `#[derive]`, or first if
/// there is none. Nothing is added when all required traits are already present.
fn add_missing_derives(attrs: &mut Vec<Attribute>, required: &[&str]) {
    let existing: &[Attribute] = attrs;
    let missing: Vec<syn::Path> = required
        .iter()
        .filter(|name| !has_derived_trait(existing, name))
        .filter_map(|name| -> Option<syn::Path> {
            let path = match *name {
                "Serialize" => syn::parse_quote!(::typeshift::serde::Serialize),
                "Deserialize" => syn::parse_quote!(::typeshift::serde::Deserialize),
                "Validate" => syn::parse_quote!(::typeshift::validator::Validate),
                "JsonSchema" => syn::parse_quote!(::typeshift::schemars::JsonSchema),
                _ => return None,
            };
            Some(path)
        })
        .collect();
    if missing.is_empty() {
        return;
    }
    let insert_at = match attrs.iter().rposition(|attr| attr.path().is_ident("derive")) {
        Some(last_derive) => last_derive + 1,
        None => 0,
    };
    attrs.insert(insert_at, syn::parse_quote!(#[derive(#(#missing),*)]));
}
/// Appends `#[name(args)]` unless an existing `#[name(...)]` attribute already
/// sets `crate = ...`.
///
/// If `args` cannot be tokenized, the function silently does nothing.
fn ensure_attr(attrs: &mut Vec<Attribute>, name: &str, args: &str) {
    let attr_name = syn::Ident::new(name, proc_macro2::Span::call_site());
    if let Ok(arg_tokens) = args.parse::<proc_macro2::TokenStream>() {
        let already_configured = attrs
            .iter()
            .any(|attr| attr.path().is_ident(name) && attr_has_crate_arg(attr));
        if !already_configured {
            attrs.push(syn::parse_quote!(#[#attr_name(#arg_tokens)]));
        }
    }
}
/// Reports whether `attr`'s comma-separated arguments include a `crate = ...`
/// name-value pair.
///
/// Attributes whose arguments do not parse as a list of metas yield `false`.
fn attr_has_crate_arg(attr: &Attribute) -> bool {
    type MetaList = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
    match attr.parse_args_with(MetaList::parse_terminated) {
        Ok(metas) => metas.iter().any(|meta| {
            matches!(meta, syn::Meta::NameValue(pair) if pair.path.is_ident("crate"))
        }),
        Err(_) => false,
    }
}