mod attr;
mod gen;
mod parse;
use gen::{FieldDef, StructDef};
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, GenericParam, Ident, Result};
/// Entry point for `#[derive(Deserialize)]` (helper attribute: `#[serde(...)]`).
///
/// All real work happens in `derive_deserialize_impl`. If that returns a
/// `syn::Error`, the error is converted with `to_compile_error()` and emitted
/// as the macro's output, so it is reported as a compile error at the derive site.
#[proc_macro_derive(Deserialize, attributes(serde))]
pub fn derive_deserialize(input: TokenStream) -> TokenStream {
    derive_deserialize_impl(input).unwrap_or_else(|err| err.to_compile_error().into())
}
fn derive_deserialize_impl(input: TokenStream) -> Result<TokenStream> {
let input: DeriveInput = syn::parse(input)?;
let _fields = parse::validate_input(&input)?;
let parsed = attr::parse_struct_attrs(&input)?;
let struct_def = StructDef {
ident: input.ident,
generics: input.generics,
container_attrs: parsed.container_attrs,
fields: parsed
.fields
.into_iter()
.map(|f| FieldDef {
ident: f.ident,
ty: f.ty,
attrs: f.attrs,
})
.collect(),
};
let deserialize_body = gen::deserialize(&struct_def);
let deserialize_in_place_body = gen::deserialize_in_place(&struct_def);
let struct_ident = &struct_def.ident;
let (_impl_generics, ty_generics, where_clause) = struct_def.generics.split_for_impl();
let generic_params = &struct_def.generics.params;
let output = if generic_params.is_empty() {
quote! {
#[automatically_derived]
impl<'de> serde::Deserialize<'de> for #struct_ident #where_clause {
#deserialize_body
#deserialize_in_place_body
}
}
} else {
let deserializable_type_params: std::collections::HashSet<Ident> = struct_def
.fields
.iter()
.filter(|f| !f.attrs.skip_deserializing)
.flat_map(|f| extract_type_params(&f.ty, &struct_def.generics))
.collect();
let impl_params = generic_params.iter().map(|param| match param {
GenericParam::Type(type_param) => {
let ident = &type_param.ident;
let existing_bounds = &type_param.bounds;
let needs_deserialize = deserializable_type_params.contains(ident);
match (existing_bounds.is_empty(), needs_deserialize) {
(true, true) => quote! { #ident: serde::Deserialize<'de> },
(true, false) => quote! { #ident },
(false, true) => quote! { #ident: #existing_bounds + serde::Deserialize<'de> },
(false, false) => quote! { #ident: #existing_bounds },
}
}
GenericParam::Lifetime(lt) => quote! { #lt },
GenericParam::Const(cp) => quote! { #cp },
});
let where_clause_output = if let Some(wc) = where_clause {
quote! { #wc }
} else {
quote! {}
};
quote! {
#[automatically_derived]
impl<'de, #(#impl_params),*> serde::Deserialize<'de> for #struct_ident #ty_generics #where_clause_output {
#deserialize_body
#deserialize_in_place_body
}
}
};
Ok(output.into())
}
/// Returns owned copies of those type parameters of `generics` that
/// `gen::collect_type_param_idents` reports for `ty`.
///
/// Only type parameters are passed as candidates; lifetime and const
/// parameters are never considered. Order and duplicates are whatever the
/// helper produces.
fn extract_type_params(ty: &syn::Type, generics: &syn::Generics) -> Vec<syn::Ident> {
    let candidates: Vec<&syn::Ident> = generics.type_params().map(|tp| &tp.ident).collect();
    let mut matches = Vec::new();
    gen::collect_type_param_idents(ty, &candidates, &mut matches);
    matches.into_iter().cloned().collect()
}