Skip to main content

dbt_yaml_derive/
lib.rs

1extern crate proc_macro2;
2extern crate quote;
3extern crate syn;
4
5extern crate proc_macro;
6
7use std::str::FromStr;
8
9use heck::ToKebabCase as _;
10use heck::ToLowerCamelCase as _;
11use heck::ToPascalCase as _;
12use heck::ToSnakeCase as _;
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::DeriveInput;
16use syn::parse_macro_input;
17use syn::spanned::Spanned;
18
19struct Variant<'a> {
20    ident: syn::Ident,
21    fields: &'a syn::Fields,
22}
23
24impl<'a> Variant<'a> {
25    pub fn try_from_ast(variant: &'a syn::Variant) -> syn::Result<Self> {
26        if variant
27            .attrs
28            .iter()
29            .any(|attr| attr.path().is_ident("serde"))
30        {
31            return Err(syn::Error::new(
32                variant.span(),
33                "UntaggedEnumDeserialize: #[serde(..)] attributes on variants are not supported",
34            ));
35        }
36
37        Ok(Variant {
38            ident: variant.ident.clone(),
39            fields: &variant.fields,
40        })
41    }
42
43    fn gen_untagged_type_name(&self) -> syn::Result<proc_macro2::TokenStream> {
44        match self.fields {
45            syn::Fields::Unit => Ok(quote! { <() as __serde::Deserialize> }),
46            syn::Fields::Unnamed(fields) => {
47                if fields.unnamed.len() == 1 {
48                    // If there's only one unnamed field, we can use its type directly
49                    let ty = &fields.unnamed[0].ty;
50                    Ok(quote! { <#ty as __serde::Deserialize> })
51                } else {
52                    // If there are multiple unnamed fields, we create a tuple type
53                    let types = fields
54                        .unnamed
55                        .iter()
56                        .map(|f| f.ty.clone())
57                        .collect::<Vec<_>>();
58                    Ok(quote! { <(#(#types),*) as __serde::Deserialize> })
59                }
60            }
61            syn::Fields::Named(_) => Err(syn::Error::new(
62                self.ident.span(),
63                "UntaggedEnumDeserialize: inlined struct variants are not supported -- use a named struct type instead",
64            )),
65        }
66    }
67
68    fn gen_constructor(&self) -> syn::Result<proc_macro2::TokenStream> {
69        let enum_name = &self.ident;
70        match self.fields {
71            syn::Fields::Unit => Ok(quote! { #enum_name }),
72            syn::Fields::Unnamed(fields) => {
73                if fields.unnamed.len() == 1 {
74                    Ok(quote! { #enum_name(__inner) })
75                } else {
76                    let elems = (0..fields.unnamed.len())
77                        .map(|i| {
78                            let i = syn::Index::from(i);
79                            quote! { __inner.#i }
80                        })
81                        .collect::<Vec<proc_macro2::TokenStream>>();
82                    Ok(quote! { #enum_name(#(#elems),*) })
83                }
84            }
85            syn::Fields::Named(_) => Err(syn::Error::new(
86                self.ident.span(),
87                "UntaggedEnumDeserialize: inlined struct variants are not supported -- use a named struct type instead",
88            )),
89        }
90    }
91
92    fn get_name(&self, default_rename_policy: Option<RenamePolicy>) -> String {
93        if let Some(policy) = default_rename_policy {
94            policy.apply(&self.ident)
95        } else {
96            self.ident.to_string()
97        }
98    }
99
100    fn gen_tagged_deserialize_expr(
101        &self,
102        enum_name: &syn::Ident,
103    ) -> syn::Result<proc_macro2::TokenStream> {
104        match self.fields {
105            syn::Fields::Unit => {
106                let enum_name = enum_name.to_string();
107                let variant_name = self.ident.to_string();
108
109                Ok(quote! {
110                    __serde::Deserializer::deserialize_any(
111                        __deserializer,
112                        __serde_yaml::__private::InternallyTaggedUnitVisitor::new(
113                            #enum_name,
114                            #variant_name
115                        )
116                    )
117                })
118            }
119            syn::Fields::Unnamed(fields) => {
120                if fields.unnamed.len() == 1 {
121                    let ty = &fields.unnamed[0].ty;
122
123                    Ok(quote! {
124                        <#ty as __serde::Deserialize>::deserialize(__deserializer)
125                    })
126                } else {
127                    Err(syn::Error::new(
128                        self.ident.span(),
129                        "UntaggedEnumDeserialize: tuple variants are not allowed in internally tagged enums",
130                    ))
131                }
132            }
133            syn::Fields::Named(_) => Err(syn::Error::new(
134                self.ident.span(),
135                "UntaggedEnumDeserialize: inlined struct variants are not supported -- use a named struct type instead",
136            )),
137        }
138    }
139
140    fn gen_tagged_deserialize_arm(
141        &self,
142        enum_name: &syn::Ident,
143        default_rename_policy: Option<RenamePolicy>,
144    ) -> syn::Result<proc_macro2::TokenStream> {
145        let expr = self.gen_tagged_deserialize_expr(enum_name)?;
146        let constructor = self.gen_constructor()?;
147        let tag_name = if let Some(policy) = default_rename_policy {
148            policy.apply(&self.ident)
149        } else {
150            self.ident.to_string()
151        };
152
153        let block = quote! {
154            Some(#tag_name) => {
155                let __inner = #expr.map_err(|e| {
156                    __serde::de::Error::custom(e)
157                })?;
158                return Ok(#enum_name::#constructor);
159            }
160        };
161
162        Ok(block)
163    }
164
165    fn gen_untagged_deserialize_block(&self) -> syn::Result<proc_macro2::TokenStream> {
166        let type_name = self.gen_untagged_type_name()?;
167
168        let block = quote! {
169            __unused_keys.clear();
170            let __inner = {
171                let mut collect_unused_keys =
172                    |path: __serde_yaml::Path<'_>, key: &__serde_yaml::Value, value: &__serde_yaml::Value| {
173                        __unused_keys.push((path.to_owned_path(), key.clone(), value.clone()));
174                    };
175
176                #type_name::deserialize(__state.get_deserializer(Some(&mut collect_unused_keys)))
177            };
178        };
179
180        Ok(block)
181    }
182
183    fn gen_constructor_block(
184        &self,
185        enum_name: &syn::Ident,
186    ) -> syn::Result<proc_macro2::TokenStream> {
187        let constructor = self.gen_constructor()?;
188
189        let block = quote! {
190            if let Ok(__inner) = __inner {
191                if let Some(mut __callback) = __unused_key_callback {
192                    for (path, key, value) in __unused_keys.iter() {
193                        __callback(*path.as_path(), key, value);
194                    }
195                }
196                return Ok(#enum_name::#constructor);
197            }
198        };
199
200        Ok(block)
201    }
202}
203
204#[allow(clippy::enum_variant_names)]
205#[derive(Debug, Clone, Copy, PartialEq, Eq)]
206enum RenamePolicy {
207    /// Rename the field to its snake_case equivalent
208    SnakeCase,
209    /// Rename the field to its camelCase equivalent
210    CamelCase,
211    /// Rename the field to its lower_case equivalent
212    LowerCase,
213    /// Rename the field to its UPPER_CASE equivalent
214    UpperCase,
215    /// Rename the field to its PascalCase equivalent
216    PascalCase,
217    /// Rename the field to its kebab-case equivalent
218    KebabCase,
219}
220
221impl FromStr for RenamePolicy {
222    type Err = syn::Error;
223
224    fn from_str(s: &str) -> Result<Self, Self::Err> {
225        match s {
226            "snake_case" => Ok(RenamePolicy::SnakeCase),
227            "camelCase" => Ok(RenamePolicy::CamelCase),
228            "lowercase" => Ok(RenamePolicy::LowerCase),
229            "UPPERCASE" => Ok(RenamePolicy::UpperCase),
230            "PascalCase" => Ok(RenamePolicy::PascalCase),
231            "kebab-case" => Ok(RenamePolicy::KebabCase),
232            _ => Err(syn::Error::new(
233                proc_macro2::Span::call_site(),
234                format!("Unknown rename policy: {s}"),
235            )),
236        }
237    }
238}
239
240impl RenamePolicy {
241    fn apply(&self, ident: &syn::Ident) -> String {
242        match self {
243            RenamePolicy::SnakeCase => ident.to_string().to_snake_case(),
244            RenamePolicy::CamelCase => ident.to_string().to_lower_camel_case(),
245            RenamePolicy::LowerCase => ident.to_string().to_lowercase(),
246            RenamePolicy::UpperCase => ident.to_string().to_uppercase(),
247            RenamePolicy::PascalCase => ident.to_string().to_pascal_case(),
248            RenamePolicy::KebabCase => ident.to_string().to_kebab_case(),
249        }
250    }
251}
252
253struct EnumDef<'a> {
254    ident: syn::Ident,
255    generics: &'a syn::Generics,
256    variants: Vec<Variant<'a>>,
257    tag: Option<String>,
258    rename_all: Option<RenamePolicy>,
259}
260
261impl<'a> EnumDef<'a> {
262    pub fn try_from_ast(input: &'a DeriveInput) -> syn::Result<Self> {
263        // Check if the input is an enum
264        let syn::Data::Enum(data_enum) = &input.data else {
265            return Err(syn::Error::new(
266                input.span(),
267                "UntaggedEnumDeserialize: can only be derived for enums",
268            ));
269        };
270
271        // Check for #[serde(untagged)] attribute
272        let has_untagged_attr = input.attrs.iter().any(|attr| {
273            if !attr.path().is_ident("serde") {
274                return false;
275            }
276            if let Ok(syn::Expr::Path(expr_path)) = attr.parse_args() {
277                return expr_path.path.is_ident("untagged");
278            }
279            false
280        });
281        // Check for #[serde(tag = "...")] attribute
282        let tag_attr = input.attrs.iter().find_map(|attr| {
283            if !attr.path().is_ident("serde") {
284                return None;
285            }
286            let Ok(syn::Expr::Assign(expr)) = attr.parse_args() else {
287                return None;
288            };
289            let syn::Expr::Path(expr_path) = *expr.left else {
290                return None;
291            };
292            if !expr_path.path.is_ident("tag") {
293                return None;
294            }
295
296            match *expr.right {
297                syn::Expr::Lit(lit) => {
298                    match lit.lit {
299                        syn::Lit::Str(lit) => Some(lit.value()),
300                        _ => None, // Invalid tag attribute
301                    }
302                }
303                _ => None,
304            }
305        });
306
307        if !has_untagged_attr && tag_attr.is_none() {
308            return Err(syn::Error::new(
309                input.span(),
310                "UntaggedEnumDeserialize: can only be derived for enums with #[serde(untagged)] or #[serde(tag = \"...\")] attributes",
311            ));
312        }
313
314        // Extract any #[serde(rename_all = "...")] directives
315        let rename_all_attr = input.attrs.iter().find_map(|attr| {
316            if !attr.path().is_ident("serde") {
317                return None;
318            }
319            let Ok(syn::Expr::Assign(expr)) = attr.parse_args() else {
320                return None;
321            };
322            let syn::Expr::Path(expr_path) = *expr.left else {
323                return None;
324            };
325            if !expr_path.path.is_ident("rename_all") {
326                return None;
327            }
328
329            match *expr.right {
330                syn::Expr::Lit(lit) => {
331                    match lit.lit {
332                        syn::Lit::Str(lit) => Some(lit.value()),
333                        _ => None, // Invalid rename_all attribute
334                    }
335                }
336                _ => None,
337            }
338        });
339        let rename_all = rename_all_attr
340            .map(|a| RenamePolicy::from_str(a.as_str()))
341            .transpose()?;
342
343        // Check the enum has no borrowed lifetimes
344        for param in &input.generics.params {
345            if let syn::GenericParam::Lifetime(lifetime_param) = param {
346                return Err(syn::Error::new(
347                    lifetime_param.lifetime.span(),
348                    "UntaggedEnumDeserialize: borrowed lifetimes are not supported",
349                ));
350            }
351        }
352
353        let ident = input.ident.clone();
354        let generics = &input.generics;
355        let variants = data_enum
356            .variants
357            .iter()
358            .map(Variant::try_from_ast)
359            .collect::<syn::Result<Vec<_>>>()?;
360        Ok(EnumDef {
361            ident,
362            generics,
363            variants,
364            tag: tag_attr,
365            rename_all,
366        })
367    }
368
369    fn build_impl_generics(&self) -> syn::Generics {
370        let mut generics = self.generics.clone();
371        // Inject a 'de lifetime bound for deserialization
372        generics
373            .params
374            .push(syn::GenericParam::Lifetime(syn::LifetimeParam {
375                attrs: Vec::new(),
376                lifetime: syn::Lifetime::new("'de", self.ident.span()),
377                colon_token: None,
378                bounds: syn::punctuated::Punctuated::new(),
379            }));
380
381        // Inject a where clause bound `T: serde::de::Deserialize<'_>` for each
382        // non-lifetime type parameter `T`:
383        for param in &mut generics.params {
384            if let syn::GenericParam::Type(ty_param) = param {
385                ty_param
386                    .bounds
387                    .push(syn::parse_quote!(__serde::de::DeserializeOwned));
388            }
389        }
390
391        generics
392    }
393
394    fn gen_untagged_impl(&self) -> syn::Result<proc_macro2::TokenStream> {
395        let enum_name = &self.ident;
396        let generics = self.build_impl_generics();
397        let (impl_generics, _, where_clause) = generics.split_for_impl();
398        let (_, ty_generics, _) = self.generics.split_for_impl();
399
400        let mut variant_blocks = Vec::new();
401        for variant in &self.variants {
402            let deserialize_block = variant.gen_untagged_deserialize_block()?;
403            let constructor_block = variant.gen_constructor_block(enum_name)?;
404            variant_blocks.push(quote! {
405                #deserialize_block
406                #constructor_block
407            });
408        }
409
410        let err_message = format!("data did not match any variant of untagged enum {enum_name}");
411
412        Ok(quote! {
413            #[automatically_derived]
414            impl #impl_generics __serde::Deserialize<'de> for #enum_name #ty_generics #where_clause {
415                fn deserialize<__D>(deserializer: __D) -> Result<Self, __D::Error>
416                where
417                    __D: __serde::de::Deserializer<'de>,
418                {
419                    let mut __state = __serde_yaml::value::extract_reusable_deserializer_state(deserializer)?;
420                    let __unused_key_callback = __state.take_unused_key_callback();
421                    let mut __unused_keys = vec![];
422
423                    #( #variant_blocks )*
424
425                    Err(__serde::de::Error::custom(#err_message))
426                }
427            }
428        })
429    }
430
431    fn gen_internally_tagged_impl(&self) -> syn::Result<proc_macro2::TokenStream> {
432        let enum_name = &self.ident;
433        let tag_key = self.tag.as_ref().expect("Expected tag key");
434        let generics = self.build_impl_generics();
435        let (impl_generics, _, where_clause) = generics.split_for_impl();
436        let (_, ty_generics, _) = self.generics.split_for_impl();
437
438        let variant_arms = self
439            .variants
440            .iter()
441            .map(|variant| variant.gen_tagged_deserialize_arm(enum_name, self.rename_all))
442            .collect::<syn::Result<Vec<_>>>()?;
443        let variant_names = self
444            .variants
445            .iter()
446            .map(|variant| variant.get_name(self.rename_all))
447            .collect::<Vec<_>>();
448
449        Ok(quote! {
450            #[automatically_derived]
451            impl #impl_generics __serde::Deserialize<'de> for #enum_name #ty_generics #where_clause {
452                fn deserialize<__D>(deserializer: __D) -> Result<Self, __D::Error>
453                where
454                    __D: __serde::de::Deserializer<'de>,
455                {
456                    let (__tag, mut __state) = __serde_yaml::value::extract_tag_and_deserializer_state(deserializer, #tag_key)?;
457                    let __deserializer = __state.get_owned_deserializer();
458
459                    match __tag.as_str() {
460                        #( #variant_arms )*
461                        Some(tag) => {
462                            return Err(__serde::de::Error::unknown_variant(
463                                tag,
464                                &[ #( #variant_names ),* ]
465                             ));
466                        }
467                        None => {
468                            return Err(__serde::de::Error::invalid_value(
469                                __tag.unexpected(),
470                                &"a valid tag for internally tagged enum"
471                            ));
472                        }
473                    }
474                }
475            }
476        })
477    }
478
479    fn gen_deserialize_impl(&self) -> syn::Result<proc_macro2::TokenStream> {
480        match self.tag {
481            Some(_) => self.gen_internally_tagged_impl(),
482            None => self.gen_untagged_impl(),
483        }
484    }
485}
486
487fn expand_derive_deserialize(
488    input: &mut syn::DeriveInput,
489) -> syn::Result<proc_macro2::TokenStream> {
490    let enum_def = EnumDef::try_from_ast(input)?;
491    let deserialize_impl = enum_def.gen_deserialize_impl()?;
492
493    let block = quote! {
494        const _: () = {
495            #[allow(unused_extern_crates, clippy::useless_attribute)]
496            extern crate dbt_yaml as __serde_yaml;
497            #[allow(unused_extern_crates, clippy::useless_attribute)]
498            extern crate serde as __serde;
499            #deserialize_impl
500        };
501    };
502
503    Ok(block)
504}
505
506#[proc_macro_derive(UntaggedEnumDeserialize, attributes(serde))]
507pub fn derive_deserialize(input: TokenStream) -> TokenStream {
508    let mut input = parse_macro_input!(input as DeriveInput);
509
510    expand_derive_deserialize(&mut input)
511        .unwrap_or_else(syn::Error::into_compile_error)
512        .into()
513}