extern crate proc_macro;
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
Data, DeriveInput, FieldsNamed, Generics, Ident, ImplGenerics, Path, PathArguments, Type,
TypeGenerics, TypePath, Visibility, WhereClause, parse_macro_input,
};
#[proc_macro_derive(Partial)]
pub fn derive_partial(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let orig_vis: Visibility = input.vis.clone();
let orig_ident: Ident = input.ident.clone();
let partial_ident = format_ident!("{}Partial", orig_ident);
let maybe_derive_attr = collect_partial_derives(&input);
let Data::Struct(input_struct) = input.data else {
panic!("Optifier supports only struct types");
};
let generics = input.generics.clone();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let syn::Fields::Named(fields) = &input_struct.fields else {
panic!("Optifier supports only named fields");
};
let partial_struct_def = construct_partial_struct(
&partial_ident,
&orig_vis,
fields,
&generics,
where_clause,
maybe_derive_attr,
);
let merge_function_impl_block = construct_merge_impl_block(
&partial_ident,
fields,
&impl_generics,
&ty_generics,
where_clause,
);
let tryfrom_impl_block = construct_tryfrom_impl_block(
&orig_ident,
&partial_ident,
fields,
&impl_generics,
&ty_generics,
where_clause,
);
let generated_code = quote! {
#partial_struct_def
#merge_function_impl_block
#tryfrom_impl_block
};
TokenStream::from(generated_code)
}
/// Pass-through attribute macro: ignores its arguments and returns the
/// annotated item unchanged.
///
/// The `Partial` derive in this file looks for attributes named
/// `partial_derive` on its input (see `collect_partial_derives`) and reads
/// the listed paths from there; this macro itself generates nothing.
#[proc_macro_attribute]
pub fn partial_derive(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}
/// Gathers the paths listed in every `#[partial_derive(...)]` attribute on
/// `input` and turns them into a single `#[derive(...)]` attribute.
///
/// Returns an empty token stream when no paths were found. If an attribute
/// cannot be parsed as a nested meta list, the parse error is emitted as
/// `compile_error!` tokens ahead of the derive attribute instead of being
/// dropped, so a malformed attribute is reported rather than silently
/// ignored.
fn collect_partial_derives(input: &syn::DeriveInput) -> proc_macro2::TokenStream {
    let mut derives: Vec<proc_macro2::TokenStream> = Vec::new();
    let mut diagnostics: Vec<proc_macro2::TokenStream> = Vec::new();

    let partial_derive_attrs = input
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("partial_derive"));

    for attr in partial_derive_attrs {
        let parsed = attr.parse_nested_meta(|meta| {
            let path = &meta.path;
            derives.push(quote! { #path });
            Ok(())
        });
        if let Err(err) = parsed {
            diagnostics.push(err.to_compile_error());
        }
    }

    let derive_attr = if derives.is_empty() {
        quote! {}
    } else {
        quote! { #[derive( #(#derives),* )] }
    };

    // `compile_error!` invocations are items, so they may legally precede
    // the attributed struct definition this stream is spliced in front of.
    quote! {
        #(#diagnostics)*
        #derive_attr
    }
}
/// Returns `true` when `ty` is a path type whose path is recognised as an
/// `Option<..>` by `is_path_option`; every non-path type yields `false`.
fn is_option_type(ty: &Type) -> bool {
    if let Type::Path(TypePath { path, .. }) = ty {
        is_path_option(path)
    } else {
        false
    }
}
/// Syntactic check for `Option<..>`: the last path segment must be named
/// `Option` and carry angle-bracketed arguments.
///
/// Only the final segment is inspected, so any leading path
/// (`std::option::`, `core::option::`, or something else) is accepted.
fn is_path_option(path: &Path) -> bool {
    match path.segments.last() {
        Some(segment) => {
            let has_angle_args = matches!(segment.arguments, PathArguments::AngleBracketed(_));
            segment.ident == "Option" && has_angle_args
        }
        None => false,
    }
}
/// Emits the definition of the partial struct.
///
/// * `type_ident` – name of the struct to generate.
/// * `type_vis` – visibility placed in front of `struct`.
/// * `fields_named` – fields of the original struct; each keeps its own
///   visibility and name.
/// * `generics` / `where_clause` – emitted after the struct name.
/// * `derive_attrs` – tokens placed immediately before the definition.
///
/// A field whose type is already detected as `Option<..>` keeps its type;
/// every other field type `T` becomes `::std::option::Option<T>`.
fn construct_partial_struct(
    type_ident: &Ident,
    type_vis: &Visibility,
    fields_named: &FieldsNamed,
    generics: &Generics,
    where_clause: Option<&WhereClause>,
    derive_attrs: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let partial_fields = fields_named.named.iter().map(|field| {
        let vis = &field.vis;
        let name = field
            .ident
            .as_ref()
            .expect("Optifier: Named field must have ident");
        let ty = &field.ty;

        // Decide the field type first, then emit one shared field template.
        let partial_ty = if is_option_type(ty) {
            quote! { #ty }
        } else {
            quote! { ::std::option::Option<#ty> }
        };

        quote! {
            #vis #name: #partial_ty
        }
    });

    quote! {
        #derive_attrs
        #type_vis struct #type_ident #generics #where_clause {
            #(#partial_fields),*
        }
    }
}
/// Emits an inherent `impl` block on the partial struct containing `merge`.
///
/// The generated `pub fn merge(self, other: Self) -> Self` builds a new
/// value field by field with `self.field.or(other.field)`: a value present
/// in `self` wins, otherwise the one from `other` is used.
///
/// `other` is typed as `Self` rather than the bare struct identifier so the
/// generated method also compiles when the struct has generic parameters.
fn construct_merge_impl_block(
    type_ident: &Ident,
    fields_named: &FieldsNamed,
    impl_generics: &ImplGenerics,
    ty_generics: &TypeGenerics,
    where_clause: Option<&WhereClause>,
) -> proc_macro2::TokenStream {
    let fields_merged = fields_named.named.iter().map(|f| {
        let f_ident = f
            .ident
            .as_ref()
            .expect("Optifier: Named field must have ident");
        quote! {
            #f_ident: self.#f_ident.or(other.#f_ident)
        }
    });

    quote! {
        impl #impl_generics #type_ident #ty_generics #where_clause {
            pub fn merge(self, other: Self) -> Self {
                Self {
                    #(#fields_merged),*
                }
            }
        }
    }
}
fn construct_tryfrom_impl_block(
orig_ident: &Ident,
partial_ident: &Ident,
fields_named: &FieldsNamed,
impl_generics: &ImplGenerics,
ty_generics: &TypeGenerics,
where_clause: Option<&WhereClause>,
) -> proc_macro2::TokenStream {
let error_ident = format_ident!("{}Error", partial_ident);
let error_variants = fields_named.named.iter().filter_map(|f| {
let f_ident = f
.ident
.as_ref()
.expect("Optifier: Named field must have ident");
let f_ty = &f.ty;
if is_option_type(f_ty) {
return None;
}
let f_name_str = f_ident.to_string();
let f_name_in_pascal_case = f_name_str.to_case(Case::Pascal);
let variant_name = format!("{}Missing", f_name_in_pascal_case);
let variant_ident = format_ident!("{}", variant_name);
Some(quote! {
#[error("Field `{}` is missing", #f_name_str)]
#variant_ident
})
});
let construct_fields = fields_named.named.iter().map(|f| {
let f_ident = f
.ident
.as_ref()
.expect("Optifier: Named field must have ident");
let f_ty = &f.ty;
if is_option_type(f_ty) {
quote! {
#f_ident: partial.#f_ident
}
} else {
let raw_name = f_ident.to_string();
let pascal = raw_name.to_case(Case::Pascal);
let variant_name = format!("{}Missing", pascal);
let variant_ident = format_ident!("{}", variant_name);
quote! {
#f_ident: partial.#f_ident.ok_or(#error_ident::#variant_ident)?
}
}
});
let error_def = quote! {
#[derive(::thiserror::Error, Debug)]
pub enum #error_ident {
#(#error_variants),*
}
};
let try_from_impl = quote! {
impl #impl_generics ::std::convert::TryFrom<#partial_ident #ty_generics> for #orig_ident #ty_generics #where_clause {
type Error = #error_ident;
fn try_from(partial: #partial_ident #ty_generics) -> ::std::result::Result<#orig_ident #ty_generics, Self::Error> {
Ok(#orig_ident {
#(#construct_fields),*
})
}
}
};
quote! {
#error_def
#try_from_impl
}
}