1mod attr;
2mod gen;
3mod parse;
4
5use gen::{FieldDef, StructDef};
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{DeriveInput, GenericParam, Ident, Result};
9
10#[proc_macro_derive(Deserialize, attributes(serde))]
17pub fn derive_deserialize(input: TokenStream) -> TokenStream {
18 match derive_deserialize_impl(input) {
19 Ok(output) => output,
20 Err(err) => err.to_compile_error().into(),
21 }
22}
23
24fn derive_deserialize_impl(input: TokenStream) -> Result<TokenStream> {
25 let input: DeriveInput = syn::parse(input)?;
27
28 let _fields = parse::validate_input(&input)?;
30
31 let parsed = attr::parse_struct_attrs(&input)?;
33
34 let struct_def = StructDef {
36 ident: input.ident,
37 generics: input.generics,
38 container_attrs: parsed.container_attrs,
39 fields: parsed
40 .fields
41 .into_iter()
42 .map(|f| FieldDef {
43 ident: f.ident,
44 ty: f.ty,
45 attrs: f.attrs,
46 })
47 .collect(),
48 };
49
50 let deserialize_body = gen::deserialize(&struct_def);
52
53 let deserialize_in_place_body = gen::deserialize_in_place(&struct_def);
55
56 let struct_ident = &struct_def.ident;
58 let (_impl_generics, ty_generics, where_clause) = struct_def.generics.split_for_impl();
59
60 let generic_params = &struct_def.generics.params;
63 let output = if generic_params.is_empty() {
64 quote! {
65 #[automatically_derived]
66 impl<'de> serde::Deserialize<'de> for #struct_ident #where_clause {
67 #deserialize_body
68
69 #deserialize_in_place_body
70 }
71 }
72 } else {
73 let deserializable_type_params: std::collections::HashSet<Ident> = struct_def
75 .fields
76 .iter()
77 .filter(|f| !f.attrs.skip_deserializing)
78 .flat_map(|f| extract_type_params(&f.ty, &struct_def.generics))
79 .collect();
80
81 let impl_params = generic_params.iter().map(|param| match param {
83 GenericParam::Type(type_param) => {
84 let ident = &type_param.ident;
85 let existing_bounds = &type_param.bounds;
86 let needs_deserialize = deserializable_type_params.contains(ident);
87 match (existing_bounds.is_empty(), needs_deserialize) {
88 (true, true) => quote! { #ident: serde::Deserialize<'de> },
89 (true, false) => quote! { #ident },
90 (false, true) => quote! { #ident: #existing_bounds + serde::Deserialize<'de> },
91 (false, false) => quote! { #ident: #existing_bounds },
92 }
93 }
94 GenericParam::Lifetime(lt) => quote! { #lt },
95 GenericParam::Const(cp) => quote! { #cp },
96 });
97
98 let where_clause_output = if let Some(wc) = where_clause {
100 quote! { #wc }
101 } else {
102 quote! {}
103 };
104
105 quote! {
106 #[automatically_derived]
107 impl<'de, #(#impl_params),*> serde::Deserialize<'de> for #struct_ident #ty_generics #where_clause_output {
108 #deserialize_body
109
110 #deserialize_in_place_body
111 }
112 }
113 };
114
115 Ok(output.into())
116}
117
118fn extract_type_params(ty: &syn::Type, generics: &syn::Generics) -> Vec<syn::Ident> {
121 let type_param_idents: Vec<&syn::Ident> = generics
122 .params
123 .iter()
124 .filter_map(|p| match p {
125 GenericParam::Type(tp) => Some(&tp.ident),
126 _ => None,
127 })
128 .collect();
129
130 let mut found = Vec::new();
131 gen::collect_type_param_idents(ty, &type_param_idents, &mut found);
132 found.into_iter().cloned().collect()
133}