1extern crate proc_macro2;
2extern crate quote;
3extern crate syn;
4
5extern crate proc_macro;
6
7use std::str::FromStr;
8
9use heck::ToKebabCase as _;
10use heck::ToLowerCamelCase as _;
11use heck::ToPascalCase as _;
12use heck::ToSnakeCase as _;
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::DeriveInput;
16use syn::parse_macro_input;
17use syn::spanned::Spanned;
18
19struct Variant<'a> {
20 ident: syn::Ident,
21 fields: &'a syn::Fields,
22}
23
24impl<'a> Variant<'a> {
25 pub fn try_from_ast(variant: &'a syn::Variant) -> syn::Result<Self> {
26 if variant
27 .attrs
28 .iter()
29 .any(|attr| attr.path().is_ident("serde"))
30 {
31 return Err(syn::Error::new(
32 variant.span(),
33 "UntaggedEnumDeserialize: #[serde(..)] attributes on variants are not supported",
34 ));
35 }
36
37 Ok(Variant {
38 ident: variant.ident.clone(),
39 fields: &variant.fields,
40 })
41 }
42
43 fn gen_untagged_type_name(&self) -> syn::Result<proc_macro2::TokenStream> {
44 match self.fields {
45 syn::Fields::Unit => Ok(quote! { <() as __serde::Deserialize> }),
46 syn::Fields::Unnamed(fields) => {
47 if fields.unnamed.len() == 1 {
48 let ty = &fields.unnamed[0].ty;
50 Ok(quote! { <#ty as __serde::Deserialize> })
51 } else {
52 let types = fields
54 .unnamed
55 .iter()
56 .map(|f| f.ty.clone())
57 .collect::<Vec<_>>();
58 Ok(quote! { <(#(#types),*) as __serde::Deserialize> })
59 }
60 }
61 syn::Fields::Named(_) => Err(syn::Error::new(
62 self.ident.span(),
63 "UntaggedEnumDeserialize: inlined struct variants are not supported -- use a named struct type instead",
64 )),
65 }
66 }
67
68 fn gen_constructor(&self) -> syn::Result<proc_macro2::TokenStream> {
69 let enum_name = &self.ident;
70 match self.fields {
71 syn::Fields::Unit => Ok(quote! { #enum_name }),
72 syn::Fields::Unnamed(fields) => {
73 if fields.unnamed.len() == 1 {
74 Ok(quote! { #enum_name(__inner) })
75 } else {
76 let elems = (0..fields.unnamed.len())
77 .map(|i| {
78 let i = syn::Index::from(i);
79 quote! { __inner.#i }
80 })
81 .collect::<Vec<proc_macro2::TokenStream>>();
82 Ok(quote! { #enum_name(#(#elems),*) })
83 }
84 }
85 syn::Fields::Named(_) => Err(syn::Error::new(
86 self.ident.span(),
87 "UntaggedEnumDeserialize: inlined struct variants are not supported -- use a named struct type instead",
88 )),
89 }
90 }
91
92 fn get_name(&self, default_rename_policy: Option<RenamePolicy>) -> String {
93 if let Some(policy) = default_rename_policy {
94 policy.apply(&self.ident)
95 } else {
96 self.ident.to_string()
97 }
98 }
99
100 fn gen_tagged_deserialize_expr(
101 &self,
102 enum_name: &syn::Ident,
103 ) -> syn::Result<proc_macro2::TokenStream> {
104 match self.fields {
105 syn::Fields::Unit => {
106 let enum_name = enum_name.to_string();
107 let variant_name = self.ident.to_string();
108
109 Ok(quote! {
110 __serde::Deserializer::deserialize_any(
111 __deserializer,
112 __serde_yaml::__private::InternallyTaggedUnitVisitor::new(
113 #enum_name,
114 #variant_name
115 )
116 )
117 })
118 }
119 syn::Fields::Unnamed(fields) => {
120 if fields.unnamed.len() == 1 {
121 let ty = &fields.unnamed[0].ty;
122
123 Ok(quote! {
124 <#ty as __serde::Deserialize>::deserialize(__deserializer)
125 })
126 } else {
127 Err(syn::Error::new(
128 self.ident.span(),
129 "UntaggedEnumDeserialize: tuple variants are not allowed in internally tagged enums",
130 ))
131 }
132 }
133 syn::Fields::Named(_) => Err(syn::Error::new(
134 self.ident.span(),
135 "UntaggedEnumDeserialize: inlined struct variants are not supported -- use a named struct type instead",
136 )),
137 }
138 }
139
140 fn gen_tagged_deserialize_arm(
141 &self,
142 enum_name: &syn::Ident,
143 default_rename_policy: Option<RenamePolicy>,
144 ) -> syn::Result<proc_macro2::TokenStream> {
145 let expr = self.gen_tagged_deserialize_expr(enum_name)?;
146 let constructor = self.gen_constructor()?;
147 let tag_name = if let Some(policy) = default_rename_policy {
148 policy.apply(&self.ident)
149 } else {
150 self.ident.to_string()
151 };
152
153 let block = quote! {
154 Some(#tag_name) => {
155 let __inner = #expr.map_err(|e| {
156 __serde::de::Error::custom(e)
157 })?;
158 return Ok(#enum_name::#constructor);
159 }
160 };
161
162 Ok(block)
163 }
164
165 fn gen_untagged_deserialize_block(&self) -> syn::Result<proc_macro2::TokenStream> {
166 let type_name = self.gen_untagged_type_name()?;
167
168 let block = quote! {
169 __unused_keys.clear();
170 let __inner = {
171 let mut collect_unused_keys =
172 |path: __serde_yaml::Path<'_>, key: &__serde_yaml::Value, value: &__serde_yaml::Value| {
173 __unused_keys.push((path.to_owned_path(), key.clone(), value.clone()));
174 };
175
176 #type_name::deserialize(__state.get_deserializer(Some(&mut collect_unused_keys)))
177 };
178 };
179
180 Ok(block)
181 }
182
183 fn gen_constructor_block(
184 &self,
185 enum_name: &syn::Ident,
186 ) -> syn::Result<proc_macro2::TokenStream> {
187 let constructor = self.gen_constructor()?;
188
189 let block = quote! {
190 if let Ok(__inner) = __inner {
191 if let Some(mut __callback) = __unused_key_callback {
192 for (path, key, value) in __unused_keys.iter() {
193 __callback(*path.as_path(), key, value);
194 }
195 }
196 return Ok(#enum_name::#constructor);
197 }
198 };
199
200 Ok(block)
201 }
202}
203
204#[allow(clippy::enum_variant_names)]
205#[derive(Debug, Clone, Copy, PartialEq, Eq)]
206enum RenamePolicy {
207 SnakeCase,
209 CamelCase,
211 LowerCase,
213 UpperCase,
215 PascalCase,
217 KebabCase,
219}
220
221impl FromStr for RenamePolicy {
222 type Err = syn::Error;
223
224 fn from_str(s: &str) -> Result<Self, Self::Err> {
225 match s {
226 "snake_case" => Ok(RenamePolicy::SnakeCase),
227 "camelCase" => Ok(RenamePolicy::CamelCase),
228 "lowercase" => Ok(RenamePolicy::LowerCase),
229 "UPPERCASE" => Ok(RenamePolicy::UpperCase),
230 "PascalCase" => Ok(RenamePolicy::PascalCase),
231 "kebab-case" => Ok(RenamePolicy::KebabCase),
232 _ => Err(syn::Error::new(
233 proc_macro2::Span::call_site(),
234 format!("Unknown rename policy: {s}"),
235 )),
236 }
237 }
238}
239
240impl RenamePolicy {
241 fn apply(&self, ident: &syn::Ident) -> String {
242 match self {
243 RenamePolicy::SnakeCase => ident.to_string().to_snake_case(),
244 RenamePolicy::CamelCase => ident.to_string().to_lower_camel_case(),
245 RenamePolicy::LowerCase => ident.to_string().to_lowercase(),
246 RenamePolicy::UpperCase => ident.to_string().to_uppercase(),
247 RenamePolicy::PascalCase => ident.to_string().to_pascal_case(),
248 RenamePolicy::KebabCase => ident.to_string().to_kebab_case(),
249 }
250 }
251}
252
253struct EnumDef<'a> {
254 ident: syn::Ident,
255 generics: &'a syn::Generics,
256 variants: Vec<Variant<'a>>,
257 tag: Option<String>,
258 rename_all: Option<RenamePolicy>,
259}
260
261impl<'a> EnumDef<'a> {
262 pub fn try_from_ast(input: &'a DeriveInput) -> syn::Result<Self> {
263 let syn::Data::Enum(data_enum) = &input.data else {
265 return Err(syn::Error::new(
266 input.span(),
267 "UntaggedEnumDeserialize: can only be derived for enums",
268 ));
269 };
270
271 let has_untagged_attr = input.attrs.iter().any(|attr| {
273 if !attr.path().is_ident("serde") {
274 return false;
275 }
276 if let Ok(syn::Expr::Path(expr_path)) = attr.parse_args() {
277 return expr_path.path.is_ident("untagged");
278 }
279 false
280 });
281 let tag_attr = input.attrs.iter().find_map(|attr| {
283 if !attr.path().is_ident("serde") {
284 return None;
285 }
286 let Ok(syn::Expr::Assign(expr)) = attr.parse_args() else {
287 return None;
288 };
289 let syn::Expr::Path(expr_path) = *expr.left else {
290 return None;
291 };
292 if !expr_path.path.is_ident("tag") {
293 return None;
294 }
295
296 match *expr.right {
297 syn::Expr::Lit(lit) => {
298 match lit.lit {
299 syn::Lit::Str(lit) => Some(lit.value()),
300 _ => None, }
302 }
303 _ => None,
304 }
305 });
306
307 if !has_untagged_attr && tag_attr.is_none() {
308 return Err(syn::Error::new(
309 input.span(),
310 "UntaggedEnumDeserialize: can only be derived for enums with #[serde(untagged)] or #[serde(tag = \"...\")] attributes",
311 ));
312 }
313
314 let rename_all_attr = input.attrs.iter().find_map(|attr| {
316 if !attr.path().is_ident("serde") {
317 return None;
318 }
319 let Ok(syn::Expr::Assign(expr)) = attr.parse_args() else {
320 return None;
321 };
322 let syn::Expr::Path(expr_path) = *expr.left else {
323 return None;
324 };
325 if !expr_path.path.is_ident("rename_all") {
326 return None;
327 }
328
329 match *expr.right {
330 syn::Expr::Lit(lit) => {
331 match lit.lit {
332 syn::Lit::Str(lit) => Some(lit.value()),
333 _ => None, }
335 }
336 _ => None,
337 }
338 });
339 let rename_all = rename_all_attr
340 .map(|a| RenamePolicy::from_str(a.as_str()))
341 .transpose()?;
342
343 for param in &input.generics.params {
345 if let syn::GenericParam::Lifetime(lifetime_param) = param {
346 return Err(syn::Error::new(
347 lifetime_param.lifetime.span(),
348 "UntaggedEnumDeserialize: borrowed lifetimes are not supported",
349 ));
350 }
351 }
352
353 let ident = input.ident.clone();
354 let generics = &input.generics;
355 let variants = data_enum
356 .variants
357 .iter()
358 .map(Variant::try_from_ast)
359 .collect::<syn::Result<Vec<_>>>()?;
360 Ok(EnumDef {
361 ident,
362 generics,
363 variants,
364 tag: tag_attr,
365 rename_all,
366 })
367 }
368
369 fn build_impl_generics(&self) -> syn::Generics {
370 let mut generics = self.generics.clone();
371 generics
373 .params
374 .push(syn::GenericParam::Lifetime(syn::LifetimeParam {
375 attrs: Vec::new(),
376 lifetime: syn::Lifetime::new("'de", self.ident.span()),
377 colon_token: None,
378 bounds: syn::punctuated::Punctuated::new(),
379 }));
380
381 for param in &mut generics.params {
384 if let syn::GenericParam::Type(ty_param) = param {
385 ty_param
386 .bounds
387 .push(syn::parse_quote!(__serde::de::DeserializeOwned));
388 }
389 }
390
391 generics
392 }
393
394 fn gen_untagged_impl(&self) -> syn::Result<proc_macro2::TokenStream> {
395 let enum_name = &self.ident;
396 let generics = self.build_impl_generics();
397 let (impl_generics, _, where_clause) = generics.split_for_impl();
398 let (_, ty_generics, _) = self.generics.split_for_impl();
399
400 let mut variant_blocks = Vec::new();
401 for variant in &self.variants {
402 let deserialize_block = variant.gen_untagged_deserialize_block()?;
403 let constructor_block = variant.gen_constructor_block(enum_name)?;
404 variant_blocks.push(quote! {
405 #deserialize_block
406 #constructor_block
407 });
408 }
409
410 let err_message = format!("data did not match any variant of untagged enum {enum_name}");
411
412 Ok(quote! {
413 #[automatically_derived]
414 impl #impl_generics __serde::Deserialize<'de> for #enum_name #ty_generics #where_clause {
415 fn deserialize<__D>(deserializer: __D) -> Result<Self, __D::Error>
416 where
417 __D: __serde::de::Deserializer<'de>,
418 {
419 let mut __state = __serde_yaml::value::extract_reusable_deserializer_state(deserializer)?;
420 let __unused_key_callback = __state.take_unused_key_callback();
421 let mut __unused_keys = vec![];
422
423 #( #variant_blocks )*
424
425 Err(__serde::de::Error::custom(#err_message))
426 }
427 }
428 })
429 }
430
431 fn gen_internally_tagged_impl(&self) -> syn::Result<proc_macro2::TokenStream> {
432 let enum_name = &self.ident;
433 let tag_key = self.tag.as_ref().expect("Expected tag key");
434 let generics = self.build_impl_generics();
435 let (impl_generics, _, where_clause) = generics.split_for_impl();
436 let (_, ty_generics, _) = self.generics.split_for_impl();
437
438 let variant_arms = self
439 .variants
440 .iter()
441 .map(|variant| variant.gen_tagged_deserialize_arm(enum_name, self.rename_all))
442 .collect::<syn::Result<Vec<_>>>()?;
443 let variant_names = self
444 .variants
445 .iter()
446 .map(|variant| variant.get_name(self.rename_all))
447 .collect::<Vec<_>>();
448
449 Ok(quote! {
450 #[automatically_derived]
451 impl #impl_generics __serde::Deserialize<'de> for #enum_name #ty_generics #where_clause {
452 fn deserialize<__D>(deserializer: __D) -> Result<Self, __D::Error>
453 where
454 __D: __serde::de::Deserializer<'de>,
455 {
456 let (__tag, mut __state) = __serde_yaml::value::extract_tag_and_deserializer_state(deserializer, #tag_key)?;
457 let __deserializer = __state.get_owned_deserializer();
458
459 match __tag.as_str() {
460 #( #variant_arms )*
461 Some(tag) => {
462 return Err(__serde::de::Error::unknown_variant(
463 tag,
464 &[ #( #variant_names ),* ]
465 ));
466 }
467 None => {
468 return Err(__serde::de::Error::invalid_value(
469 __tag.unexpected(),
470 &"a valid tag for internally tagged enum"
471 ));
472 }
473 }
474 }
475 }
476 })
477 }
478
479 fn gen_deserialize_impl(&self) -> syn::Result<proc_macro2::TokenStream> {
480 match self.tag {
481 Some(_) => self.gen_internally_tagged_impl(),
482 None => self.gen_untagged_impl(),
483 }
484 }
485}
486
487fn expand_derive_deserialize(
488 input: &mut syn::DeriveInput,
489) -> syn::Result<proc_macro2::TokenStream> {
490 let enum_def = EnumDef::try_from_ast(input)?;
491 let deserialize_impl = enum_def.gen_deserialize_impl()?;
492
493 let block = quote! {
494 const _: () = {
495 #[allow(unused_extern_crates, clippy::useless_attribute)]
496 extern crate dbt_yaml as __serde_yaml;
497 #[allow(unused_extern_crates, clippy::useless_attribute)]
498 extern crate serde as __serde;
499 #deserialize_impl
500 };
501 };
502
503 Ok(block)
504}
505
506#[proc_macro_derive(UntaggedEnumDeserialize, attributes(serde))]
507pub fn derive_deserialize(input: TokenStream) -> TokenStream {
508 let mut input = parse_macro_input!(input as DeriveInput);
509
510 expand_derive_deserialize(&mut input)
511 .unwrap_or_else(syn::Error::into_compile_error)
512 .into()
513}