Skip to main content

config_derive/
lib.rs

1mod attr;
2mod gen;
3mod parse;
4
5use gen::{FieldDef, StructDef};
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{DeriveInput, GenericParam, Ident, Result};
9
10/// Derive macro that generates a custom `impl serde::Deserialize<'de>` for named structs.
11///
12/// # Remarks
13///
14/// Unlike serde's default derive, this macro produces a `deserialize_in_place()` method that only updates fields
15/// present in the deserializer's map, leaving absent fields at their current values.
16#[proc_macro_derive(Deserialize, attributes(serde))]
17pub fn derive_deserialize(input: TokenStream) -> TokenStream {
18    match derive_deserialize_impl(input) {
19        Ok(output) => output,
20        Err(err) => err.to_compile_error().into(),
21    }
22}
23
24fn derive_deserialize_impl(input: TokenStream) -> Result<TokenStream> {
25    // 1. parse the input TokenStream into a DeriveInput
26    let input: DeriveInput = syn::parse(input)?;
27
28    // 2. validate that the input is a named struct
29    let _fields = parse::validate_input(&input)?;
30
31    // 3. parse serde attributes from the struct
32    let parsed = attr::parse_struct_attrs(&input)?;
33
34    // 4. build the StructDef for code generation
35    let struct_def = StructDef {
36        ident: input.ident,
37        generics: input.generics,
38        container_attrs: parsed.container_attrs,
39        fields: parsed
40            .fields
41            .into_iter()
42            .map(|f| FieldDef {
43                ident: f.ident,
44                ty: f.ty,
45                attrs: f.attrs,
46            })
47            .collect(),
48    };
49
50    // 5. generate the deserialize method
51    let deserialize_body = gen::deserialize(&struct_def);
52
53    // 6. generate the deserialize_in_place method
54    let deserialize_in_place_body = gen::deserialize_in_place(&struct_def);
55
56    // 7. build the full impl block
57    let struct_ident = &struct_def.ident;
58    let (_impl_generics, ty_generics, where_clause) = struct_def.generics.split_for_impl();
59
60    // for non-generic structs, we just need `impl<'de>`. for generic structs, we add `T: serde::Deserialize<'de>`
61    // bounds only for type parameters that appear in deserializable (non-skipped) fields
62    let generic_params = &struct_def.generics.params;
63    let output = if generic_params.is_empty() {
64        quote! {
65            #[automatically_derived]
66            impl<'de> serde::Deserialize<'de> for #struct_ident #where_clause {
67                #deserialize_body
68
69                #deserialize_in_place_body
70            }
71        }
72    } else {
73        // collect type parameter identifiers that appear in deserializable fields
74        let deserializable_type_params: std::collections::HashSet<Ident> = struct_def
75            .fields
76            .iter()
77            .filter(|f| !f.attrs.skip_deserializing)
78            .flat_map(|f| extract_type_params(&f.ty, &struct_def.generics))
79            .collect();
80
81        // build impl generics with Deserialize<'de> bounds only on type parameters used in deserializable fields
82        let impl_params = generic_params.iter().map(|param| match param {
83            GenericParam::Type(type_param) => {
84                let ident = &type_param.ident;
85                let existing_bounds = &type_param.bounds;
86                let needs_deserialize = deserializable_type_params.contains(ident);
87                match (existing_bounds.is_empty(), needs_deserialize) {
88                    (true, true) => quote! { #ident: serde::Deserialize<'de> },
89                    (true, false) => quote! { #ident },
90                    (false, true) => quote! { #ident: #existing_bounds + serde::Deserialize<'de> },
91                    (false, false) => quote! { #ident: #existing_bounds },
92                }
93            }
94            GenericParam::Lifetime(lt) => quote! { #lt },
95            GenericParam::Const(cp) => quote! { #cp },
96        });
97
98        // build the where clause, combining existing predicates with any additional ones
99        let where_clause_output = if let Some(wc) = where_clause {
100            quote! { #wc }
101        } else {
102            quote! {}
103        };
104
105        quote! {
106            #[automatically_derived]
107            impl<'de, #(#impl_params),*> serde::Deserialize<'de> for #struct_ident #ty_generics #where_clause_output {
108                #deserialize_body
109
110                #deserialize_in_place_body
111            }
112        }
113    };
114
115    Ok(output.into())
116}
117
118/// Extracts type parameter identifiers from a type that match the struct's generic parameters.
119/// For example, given `Vec<T>` and generics `<T, U>`, this returns `{T}`.
120fn extract_type_params(ty: &syn::Type, generics: &syn::Generics) -> Vec<syn::Ident> {
121    let type_param_idents: Vec<&syn::Ident> = generics
122        .params
123        .iter()
124        .filter_map(|p| match p {
125            GenericParam::Type(tp) => Some(&tp.ident),
126            _ => None,
127        })
128        .collect();
129
130    let mut found = Vec::new();
131    gen::collect_type_param_idents(ty, &type_param_idents, &mut found);
132    found.into_iter().cloned().collect()
133}