use std::collections::HashMap;
use proc_macro_crate::{FoundCrate, crate_name};
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{ToTokens, quote};
use syn::visit::Visit;
use syn::{
Attribute, Data, DataStruct, DeriveInput, Field, Fields, GenericParam, Generics, Ident, Index,
PathArguments, Type,
};
/// `true` when this macro crate is compiled with its `serde` feature; the generated
/// patch struct then carries `#[derive(::serde::Deserialize)]`.
pub const IS_SERDE_ENABLED: bool = cfg!(feature = "serde");
/// Name of the runtime crate looked up in the user's `Cargo.toml`, and also the name
/// of the `#[patchable]` field attribute.
const PATCHABLE: &str = "patchable";
/// How a type name is used by the non-skipped fields of the input struct.
#[derive(Debug)]
enum TypeUsage {
    /// Mentioned by some field, but never recorded as the type of a `#[patchable]` field.
    NotPatchable,
    /// Recorded as the type of at least one `#[patchable]` field.
    Patchable,
}
/// Everything the code generators need to know about one derive input, collected
/// once by `MacroContext::new`.
pub(crate) struct MacroContext<'a> {
    // Name of the struct the derive is applied to.
    struct_name: &'a Ident,
    // Generics of the input struct, reused for the generated impls.
    generics: &'a Generics,
    // Fields of the input struct; the generators only look at whether they are
    // named, unnamed (tuple) or unit.
    fields: &'a Fields,
    // Type names mentioned by non-skipped fields, tagged with whether a
    // `#[patchable]` field has that type. Only entries matching the struct's type
    // parameters are ever looked up.
    preserved_types: HashMap<&'a Ident, TypeUsage>,
    // One entry per non-skipped field, in declaration order.
    field_actions: Vec<FieldAction<'a>>,
    // `<struct_name>Patch`, the name of the generated patch struct.
    patch_struct_name: Ident,
    // Path to the `Patchable` trait as seen from the crate using the derive.
    patchable_trait: TokenStream2,
    // Path to the `Patch` trait as seen from the crate using the derive.
    patch_trait: TokenStream2,
}
impl<'a> MacroContext<'a> {
    /// Validates a derive input and classifies each of its fields.
    ///
    /// Fields marked `#[patchable(skip)]` are ignored. Fields marked `#[patchable]`
    /// become `FieldAction::Patch`; every other field becomes `FieldAction::Keep`.
    /// Every type name mentioned by a non-skipped field is recorded in
    /// `preserved_types`, so the patch struct only carries the type parameters it
    /// actually uses.
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is not a struct, declares a lifetime parameter, or
    /// has a `#[patchable]` field whose type is not a plain path without generic
    /// arguments.
    pub(crate) fn new(input: &'a DeriveInput) -> syn::Result<Self> {
        let Data::Struct(DataStruct { fields, .. }) = &input.data else {
            return Err(syn::Error::new_spanned(
                input,
                "This derive macro can only be applied to structs",
            ));
        };
        if input
            .generics
            .params
            .iter()
            .any(|g| matches!(g, GenericParam::Lifetime(_)))
        {
            return Err(syn::Error::new_spanned(
                &input.generics,
                "Patch derives do not support borrowed fields",
            ));
        }
        let mut preserved_types: HashMap<&Ident, TypeUsage> = HashMap::new();
        let mut field_actions = Vec::new();
        for (index, field) in fields.iter().enumerate() {
            if has_patchable_skip_attr(field) {
                continue;
            }
            // Tuple members use the index in the *input* struct, so `self.N` and
            // `value.N` stay correct even when earlier fields were skipped.
            let member = match field.ident.as_ref() {
                Some(field_name) => FieldMember::Named(field_name),
                None => FieldMember::Unnamed(Index::from(index)),
            };
            let field_type = &field.ty;
            // Record every type name the field mentions. This matters for
            // `#[patchable]` fields too: their patch type is spelled
            // `<#ty as Patchable>::Patch`, so e.g. `T::Inner` still needs `T` to be a
            // parameter of the patch struct.
            for type_name in collect_used_simple_types(field_type) {
                preserved_types
                    .entry(type_name)
                    .or_insert(TypeUsage::NotPatchable);
            }
            let action = if has_patchable_attr(field) {
                let Some(type_name) = get_abstract_simple_type_name(field_type) else {
                    return Err(syn::Error::new_spanned(
                        field_type,
                        "Only a simple generic type is supported here",
                    ));
                };
                // `insert` rather than `or_insert`: a patchable use overrides a plain
                // use recorded by this or an earlier field.
                preserved_types.insert(type_name, TypeUsage::Patchable);
                FieldAction::Patch {
                    member,
                    ty: field_type,
                }
            } else {
                FieldAction::Keep {
                    member,
                    ty: field_type,
                }
            };
            field_actions.push(action);
        }
        let crate_path = use_site_crate_path();
        Ok(MacroContext {
            struct_name: &input.ident,
            generics: &input.generics,
            fields,
            preserved_types,
            field_actions,
            patch_struct_name: quote::format_ident!("{}Patch", &input.ident),
            patchable_trait: quote! { #crate_path :: Patchable },
            patch_trait: quote! { #crate_path :: Patch },
        })
    }

    /// Generates the patch struct definition `pub struct <Name>Patch`.
    ///
    /// The patch struct keeps the input's layout (named, tuple or unit) and stores
    /// each `#[patchable]` field as `<ty as Patchable>::Patch`. When
    /// `IS_SERDE_ENABLED` is set, it also derives `::serde::Deserialize`.
    pub(crate) fn build_patch_struct(&self) -> TokenStream2 {
        let generic_params = self.build_patch_type_generics();
        let where_clause = self.build_patch_struct_where_clause();
        let patch_fields = self.generate_patch_fields();
        let body = match &self.fields {
            Fields::Named(_) => quote! { #generic_params #where_clause { #(#patch_fields),* } },
            // Tuple structs put the where clause after the field list.
            Fields::Unnamed(_) => quote! { #generic_params ( #(#patch_fields),* ) #where_clause; },
            Fields::Unit => quote! {;},
        };
        let patch_name = &self.patch_struct_name;
        let derive_attr = if IS_SERDE_ENABLED {
            quote! { #[derive(::serde::Deserialize)] }
        } else {
            quote! {}
        };
        quote! {
            #derive_attr
            pub struct #patch_name #body
        }
    }

    /// Generates `impl Patchable for <Name>`, whose associated `Patch` type is the
    /// generated patch struct. Type parameters used as `#[patchable]` field types get
    /// a `Patchable` bound.
    pub(crate) fn build_patchable_trait_impl(&self) -> TokenStream2 {
        let patchable_trait = &self.patchable_trait;
        let (impl_generics, type_generics, _) = self.generics.split_for_impl();
        let where_clause = self.build_where_clause_with_bound(patchable_trait);
        let assoc_type_decl = self.build_associated_type_declaration();
        let input_struct_name = self.struct_name;
        quote! {
            impl #impl_generics #patchable_trait
                for #input_struct_name #type_generics
                #where_clause {
                #assoc_type_decl
            }
        }
    }

    /// Generates `impl Patch for <Name>`.
    ///
    /// Plain fields are overwritten from the patch; `#[patchable]` fields are patched
    /// recursively. Type parameters used as `#[patchable]` field types get a `Patch`
    /// bound.
    pub(crate) fn build_patch_trait_impl(&self) -> TokenStream2 {
        let patch_trait = &self.patch_trait;
        let (impl_generics, type_generics, _) = self.generics.split_for_impl();
        let where_clause = self.build_where_clause_with_bound(patch_trait);
        let input_struct_name = self.struct_name;
        // The body never reads the argument when there are no fields, so the
        // leading underscore keeps `unused_variables` quiet.
        let patch_param_name = if self.field_actions.is_empty() {
            quote! { _patch }
        } else {
            quote! { patch }
        };
        let patch_method_body = self.generate_patch_method_body();
        quote! {
            impl #impl_generics #patch_trait
                for #input_struct_name #type_generics
                #where_clause {
                #[inline(always)]
                fn patch(&mut self, #patch_param_name: Self::Patch) {
                    #patch_method_body
                }
            }
        }
    }

    /// Generates `impl From<<Name>> for <Name>Patch`.
    ///
    /// Plain fields are moved across; `#[patchable]` fields are converted with
    /// `From::from`.
    pub(crate) fn build_from_trait_impl(&self) -> TokenStream2 {
        let (impl_generics, type_generics, _) = self.generics.split_for_impl();
        let patch_type_generics = self.build_patch_type_generics();
        let where_clause = self.build_where_clause_for_from_impl();
        let input_struct_name = self.struct_name;
        let patch_struct_name = &self.patch_struct_name;
        // Same trick as `build_patch_trait_impl`: when no field is read, name the
        // argument `_value` so the generated impl does not trigger `unused_variables`.
        let value_param_name = if self.field_actions.is_empty() {
            quote! { _value }
        } else {
            quote! { value }
        };
        let from_method_body = self.build_from_method_body();
        quote! {
            impl #impl_generics ::core::convert::From<#input_struct_name #type_generics>
                for #patch_struct_name #patch_type_generics
                #where_clause {
                #[inline(always)]
                fn from(#value_param_name: #input_struct_name #type_generics) -> Self {
                    #from_method_body
                }
            }
        }
    }

    /// Builds one field declaration per non-skipped field, in order, for the patch
    /// struct.
    ///
    /// When serde is enabled, each `#[patchable]` field carries an explicit
    /// `#[serde(bound(deserialize = ...))]` requiring its patch type to be
    /// `DeserializeOwned`.
    fn generate_patch_fields(&self) -> Vec<TokenStream2> {
        let patchable_trait = &self.patchable_trait;
        self.field_actions
            .iter()
            .map(|action| match action {
                FieldAction::Keep { member, ty } => match member {
                    FieldMember::Named(name) => quote! { #name : #ty },
                    FieldMember::Unnamed(_) => quote! { #ty },
                },
                FieldAction::Patch { member, ty } => {
                    let field = match member {
                        FieldMember::Named(name) => quote! { #name : <#ty as #patchable_trait>::Patch },
                        FieldMember::Unnamed(_) => quote! { <#ty as #patchable_trait>::Patch },
                    };
                    if IS_SERDE_ENABLED {
                        // serde's `bound` attribute takes the where-predicate as a string literal.
                        let bound = quote! { <#ty as #patchable_trait>::Patch: ::serde::de::DeserializeOwned };
                        let bound_string = bound.to_string();
                        let bound_lit = syn::LitStr::new(&bound_string, Span::call_site());
                        quote! {
                            #[serde(bound(deserialize = #bound_lit))]
                            #field
                        }
                    } else {
                        quote! { #field }
                    }
                }
            })
            .collect()
    }

    /// Builds the statements of the generated `patch` method, one per non-skipped
    /// field.
    fn generate_patch_method_body(&self) -> TokenStream2 {
        let patch_trait = &self.patch_trait;
        let statements = self
            .field_actions
            .iter()
            .enumerate()
            .map(|(patch_index, action)| {
                let patch_member = patch_member(action.member(), patch_index);
                match action {
                    FieldAction::Keep { member, .. } => quote! {
                        self.#member = patch.#patch_member;
                    },
                    // Fully qualified, so the call resolves even when the `Patch`
                    // trait is not imported at the use site, and never picks up an
                    // inherent `patch` method of the field type.
                    FieldAction::Patch { member, .. } => quote! {
                        #patch_trait::patch(&mut self.#member, patch.#patch_member);
                    },
                }
            });
        quote! {
            #(#statements)*
        }
    }

    /// Builds the body of the generated `From::from`, constructing the patch struct
    /// from the `value` argument.
    fn build_from_method_body(&self) -> TokenStream2 {
        match &self.fields {
            Fields::Named(_) => {
                let field_initializers = self.field_actions.iter().map(|action| {
                    let member = action.member();
                    let value = action.build_initializer_expr();
                    quote! { #member: #value }
                });
                quote! { Self { #(#field_initializers),* } }
            }
            Fields::Unnamed(_) => {
                let field_values = self
                    .field_actions
                    .iter()
                    .map(|action| action.build_initializer_expr());
                quote! { Self(#(#field_values),*) }
            }
            Fields::Unit => {
                debug_assert!(self.field_actions.is_empty());
                quote! { Self }
            }
        }
    }

    /// Returns the input's type parameters that were recorded as the type of some
    /// `#[patchable]` field, in declaration order.
    fn iter_patchable_type_params(&self) -> impl Iterator<Item = &Ident> + '_ {
        self.generics.type_params().filter_map(|param| {
            matches!(
                self.preserved_types.get(&param.ident),
                Some(TypeUsage::Patchable)
            )
            .then_some(&param.ident)
        })
    }

    /// Returns the input's type parameters mentioned by at least one non-skipped
    /// field, in declaration order. These become the patch struct's parameters.
    fn iter_preserved_type_params(&self) -> impl Iterator<Item = &Ident> + '_ {
        self.generics.type_params().filter_map(|param| {
            self.preserved_types
                .contains_key(&param.ident)
                .then_some(&param.ident)
        })
    }

    /// Builds `type Patch = <Name>Patch<...>;` for the `Patchable` impl.
    fn build_associated_type_declaration(&self) -> TokenStream2 {
        let patch_type_generics = self.build_patch_type_generics();
        let state_name = &self.patch_struct_name;
        quote! {
            type Patch = #state_name #patch_type_generics;
        }
    }

    /// Builds `<P1, P2, ...>`: the preserved type parameters, names only.
    fn build_patch_type_generics(&self) -> TokenStream2 {
        let patch_generic_params = self.iter_preserved_type_params();
        quote! { <#(#patch_generic_params),*> }
    }

    /// Returns the input's where clause extended with `T: #bound` for every patchable
    /// type parameter `T`.
    fn build_where_clause_with_bound(&self, bound: &TokenStream2) -> TokenStream2 {
        self.build_where_clause_for_patchable_types(|ty| quote! { #ty: #bound, })
    }

    /// Returns the where clause for the `From` impl.
    ///
    /// For every patchable type parameter `T`, it requires `T: Patchable` and
    /// `<T as Patchable>::Patch: From<T>`.
    fn build_where_clause_for_from_impl(&self) -> TokenStream2 {
        let patchable_trait = &self.patchable_trait;
        self.build_where_clause_for_patchable_types(|ty| {
            quote! {
                #ty: #patchable_trait,
                <#ty as #patchable_trait>::Patch: ::core::convert::From<#ty>,
            }
        })
    }

    /// Applies `build_bounds` to each patchable type parameter and appends the
    /// resulting predicates to the input's where clause.
    fn build_where_clause_for_patchable_types<F>(&self, build_bounds: F) -> TokenStream2
    where
        F: Fn(&Ident) -> TokenStream2,
    {
        let bounded_types: Vec<_> = self
            .iter_patchable_type_params()
            .map(build_bounds)
            .collect();
        self.extend_where_clause(bounded_types)
    }

    /// Returns the where clause for the patch struct *definition*.
    ///
    /// Unlike the impls, the patch struct only declares the preserved type
    /// parameters. So input predicates that mention any other type parameter are
    /// dropped; copying them would reference an undeclared type. Every patchable
    /// type parameter gets a `Patchable` bound.
    fn build_patch_struct_where_clause(&self) -> TokenStream2 {
        let patchable_trait = &self.patchable_trait;
        let inherited = self
            .generics
            .where_clause
            .iter()
            .flat_map(|where_clause| where_clause.predicates.iter())
            .filter(|predicate| self.predicate_fits_patch_struct(predicate))
            .map(|predicate| quote! { #predicate, });
        let added = self
            .iter_patchable_type_params()
            .map(|ty| quote! { #ty: #patchable_trait, });
        let predicates: Vec<TokenStream2> = inherited.chain(added).collect();
        if predicates.is_empty() {
            quote! {}
        } else {
            quote! { where #(#predicates)* }
        }
    }

    /// Returns `true` if every type parameter of the input mentioned by `predicate`
    /// is also a parameter of the patch struct.
    fn predicate_fits_patch_struct(&self, predicate: &syn::WherePredicate) -> bool {
        let mut collector = SimpleTypeCollector {
            used_simple_types: Vec::new(),
        };
        collector.visit_where_predicate(predicate);
        collector.used_simple_types.into_iter().all(|ident| {
            self.preserved_types.contains_key(ident)
                || !self.generics.type_params().any(|param| param.ident == *ident)
        })
    }

    /// Appends `bounds` (each already terminated by `,`) to the input's where clause,
    /// creating one if the input has none.
    fn extend_where_clause(&self, bounds: Vec<TokenStream2>) -> TokenStream2 {
        match (&self.generics.where_clause, bounds.is_empty()) {
            (None, true) => quote! {},
            (None, false) => quote! { where #(#bounds)* },
            (Some(where_clause), true) => quote! { #where_clause },
            (Some(where_clause), false) => {
                // `empty_or_trailing` also covers a bare `where` with no predicates;
                // adding a separator there would produce `where , T: Bound`.
                let sep = (!where_clause.predicates.empty_or_trailing()).then_some(quote! {,});
                quote! { #where_clause #sep #(#bounds)* }
            }
        }
    }
}
/// How a field is addressed: by name in braced structs, by position in tuple structs.
enum FieldMember<'a> {
    /// Accessed as `.name`.
    Named(&'a Ident),
    /// Accessed as `.0`, `.1`, ...
    Unnamed(Index),
}

impl ToTokens for FieldMember<'_> {
    /// Emits the bare member (`name` or `0`), ready to follow a `.` in generated code.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        match *self {
            Self::Named(ident) => ident.to_tokens(tokens),
            Self::Unnamed(ref index) => index.to_tokens(tokens),
        }
    }
}
/// What the generated code does with one non-skipped field of the input struct.
enum FieldAction<'a> {
    /// The field is copied as-is into the patch, which stores a value of type `ty`.
    Keep {
        member: FieldMember<'a>,
        ty: &'a Type,
    },
    /// The field is marked `#[patchable]`; the patch stores `<ty as Patchable>::Patch`.
    Patch {
        member: FieldMember<'a>,
        ty: &'a Type,
    },
}

impl<'a> FieldAction<'a> {
    /// The member through which this field is accessed on the input struct.
    fn member(&self) -> &FieldMember<'a> {
        let (Self::Keep { member, .. } | Self::Patch { member, .. }) = self;
        member
    }

    /// Builds the expression that moves this field out of a binding named `value`.
    /// Patchable fields are additionally converted with `From::from`.
    fn build_initializer_expr(&self) -> TokenStream2 {
        let member = self.member();
        if let Self::Patch { .. } = self {
            quote! { ::core::convert::From::from(value.#member) }
        } else {
            quote! { value.#member }
        }
    }
}
/// Returns the member used to read a field from the *patch* value.
///
/// Named fields keep their name. Tuple fields are renumbered by `patch_index`, their
/// position among the non-skipped fields, because skipped fields are absent from the
/// patch struct.
fn patch_member(member: &FieldMember<'_>, patch_index: usize) -> TokenStream2 {
    if let FieldMember::Named(name) = member {
        return name.to_token_stream();
    }
    Index::from(patch_index).into_token_stream()
}
/// Returns the path under which the `patchable` crate is reachable from the crate
/// being expanded: `crate` inside `patchable` itself, otherwise `::<dependency name>`.
///
/// # Panics
///
/// Panics if `patchable` cannot be found in the using crate's `Cargo.toml`.
pub fn use_site_crate_path() -> TokenStream2 {
    let crate_ident = match crate_name(PATCHABLE)
        .expect("patchable library should be present in `Cargo.toml`")
    {
        FoundCrate::Itself => return quote! { crate },
        FoundCrate::Name(name) => Ident::new(&name, Span::call_site()),
    };
    quote! { ::#crate_ident }
}
/// Returns `true` if `attr` is spelled `#[patchable ...]` with a bare,
/// single-segment path.
#[inline]
fn is_patchable_attr(attr: &Attribute) -> bool {
    matches!(attr.path().get_ident(), Some(ident) if ident == PATCHABLE)
}
/// Returns `true` if `attr` is a `#[patchable(...)]` attribute whose argument list
/// consists only of `param`.
///
/// Any other argument makes the parse fail, in which case this returns `false`.
fn patchable_attr_has_param(attr: &Attribute, param: &str) -> bool {
    if !is_patchable_attr(attr) {
        return false;
    }
    let parsed = attr.parse_nested_meta(|meta| {
        if !meta.path.is_ident(param) {
            return Err(meta.error("unrecognized `patchable` parameter"));
        }
        Ok(())
    });
    parsed.is_ok()
}
fn has_patchable_attr(field: &Field) -> bool {
field.attrs.iter().any(is_patchable_attr)
}
/// Returns `true` if `field` is marked `#[patchable(skip)]` and must be left out of
/// the generated patch.
pub fn has_patchable_skip_attr(field: &Field) -> bool {
    for attr in &field.attrs {
        if patchable_attr_has_param(attr, "skip") {
            return true;
        }
    }
    false
}
/// Syntax visitor that records the leading identifier of every type path without a
/// `<Q as Trait>::` qualifier, including paths nested inside other types.
struct SimpleTypeCollector<'a> {
    /// Leading identifiers in visit order; duplicates are kept.
    used_simple_types: Vec<&'a Ident>,
}

impl<'ast> Visit<'ast> for SimpleTypeCollector<'ast> {
    fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) {
        // Qualified paths are not recorded themselves, but the recursive visit
        // below still reaches the types nested inside them.
        if let (None, Some(leading)) = (&type_path.qself, type_path.path.segments.first()) {
            self.used_simple_types.push(&leading.ident);
        }
        syn::visit::visit_type_path(self, type_path);
    }
}
/// Returns the leading identifier of every unqualified type path found anywhere in
/// `ty`, in visit order and with duplicates.
fn collect_used_simple_types(ty: &Type) -> Vec<&Ident> {
    let mut visitor = SimpleTypeCollector {
        used_simple_types: Vec::new(),
    };
    visitor.visit_type(ty);
    let SimpleTypeCollector { used_simple_types } = visitor;
    used_simple_types
}
fn get_abstract_simple_type_name(t: &Type) -> Option<&Ident> {
match t {
Type::Path(tp) if !tp.path.segments.is_empty() => {
let last_segment = tp.path.segments.last()?;
if matches!(last_segment.arguments, PathArguments::None) {
Some(&last_segment.ident)
} else {
None
}
}
_ => None,
}
}