cynic_codegen/enum_derive/mod.rs

1use {
2    proc_macro2::{Span, TokenStream},
3    std::collections::BTreeMap,
4};
5
6use crate::{
7    error::Errors,
8    idents::RenameAll,
9    schema::{
10        types::{EnumType, EnumValue},
11        Schema, Unvalidated,
12    },
13};
14
15pub(crate) mod input;
16
17pub use input::EnumDeriveInput;
18use {
19    crate::suggestions::{format_guess, guess_field},
20    input::EnumDeriveVariant,
21};
22
23pub fn enum_derive(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
24    use {darling::FromDeriveInput, syn::spanned::Spanned};
25
26    let enum_span = ast.span();
27
28    match EnumDeriveInput::from_derive_input(ast) {
29        Ok(input) => {
30            let schema = Schema::new(input.schema_input()?);
31
32            enum_derive_impl(input, &schema, enum_span).or_else(|e| Ok(e.to_compile_errors()))
33        }
34        Err(e) => Ok(e.write_errors()),
35    }
36}
37
38pub fn enum_derive_impl(
39    input: EnumDeriveInput,
40    schema: &Schema<'_, Unvalidated>,
41    enum_span: Span,
42) -> Result<TokenStream, Errors> {
43    use quote::quote;
44
45    let enum_def = schema
46        .lookup::<EnumType<'_>>(&input.graphql_type_name())
47        .map_err(|e| syn::Error::new(input.graphql_type_span(), e))?;
48
49    let rename_all = input.rename_all.unwrap_or(RenameAll::ScreamingSnakeCase);
50
51    input.validate()?;
52
53    if let darling::ast::Data::Enum(variants) = &input.data {
54        let fallback = variants.iter().find(|variant| *variant.fallback);
55
56        if input.non_exhaustive && fallback.is_none() {
57            return Err(syn::Error::new(
58                enum_span,
59                "Enum marked as non-exhaustive must have a fallback variant".to_string(),
60            )
61            .into());
62        }
63
64        let pairs = match join_variants(
65            variants.iter().map(|variant| variant.as_ref()),
66            &enum_def,
67            &input.ident.to_string(),
68            rename_all,
69            !input.non_exhaustive,
70            &enum_span,
71        ) {
72            Ok(pairs) => pairs,
73            Err(error_tokens) => return Ok(error_tokens),
74        };
75
76        let graphql_type_name = proc_macro2::Literal::string(&input.graphql_type_name());
77        let enum_marker_ident = enum_def.marker_ident().to_rust_ident();
78
79        let string_literals: Vec<_> = pairs
80            .iter()
81            .map(|(_, value)| value.name.to_literal())
82            .collect();
83
84        let variants: Vec<_> = pairs.iter().map(|(variant, _)| &variant.ident).collect();
85        let variant_indexes: Vec<_> = pairs
86            .iter()
87            .enumerate()
88            .map(|(i, _)| {
89                proc_macro2::Literal::u32_suffixed(
90                    i.try_into().expect("an enum with less than 2^32 variants"),
91                )
92            })
93            .collect();
94
95        let schema_module = input.schema_module();
96        let ident = input.ident;
97
98        let fallback_ser_branch = match fallback {
99            None => quote! {},
100            Some(fallback) if fallback.fields.fields.is_empty() => {
101                let fallback_ident = &fallback.ident;
102                quote! {
103                    #ident::#fallback_ident => {
104                        use cynic::serde::ser::Error;
105                        Err(__S::Error::custom("cynic can't serialize the fallback variant of an enum unless it has a field"))
106                    }
107                }
108            }
109            Some(fallback) => {
110                let fallback_ident = &fallback.ident;
111                quote! {
112                    #ident::#fallback_ident(value) => {
113                        serializer.serialize_str(value)
114                    }
115                }
116            }
117        };
118
119        let fallback_deser_branch = match fallback {
120            None => quote! {
121                unknown => {
122                    const VARIANTS: &'static [&'static str] = &[#(#string_literals),*];
123                    Err(cynic::serde::de::Error::unknown_variant(unknown, VARIANTS))
124                }
125            },
126            Some(fallback) if fallback.fields.fields.is_empty() => {
127                let fallback_ident = &fallback.ident;
128                quote! {
129                     _ => {
130                        Ok(#ident::#fallback_ident)
131                     }
132                }
133            }
134            Some(fallback) => {
135                let fallback_ident = &fallback.ident;
136                quote! {
137                     _ => {
138                        Ok(#ident::#fallback_ident(desered_string))
139                     }
140                }
141            }
142        };
143
144        Ok(quote! {
145            #[automatically_derived]
146            impl cynic::Enum for #ident {
147                type SchemaType = #schema_module::#enum_marker_ident;
148            }
149
150            #[automatically_derived]
151            impl cynic::serde::Serialize for #ident {
152                fn serialize<__S>(&self, serializer: __S) -> Result<__S::Ok, __S::Error>
153                where
154                    __S: cynic::serde::Serializer {
155                        match self {
156                            #(
157                                #ident::#variants => serializer.serialize_unit_variant(#graphql_type_name, #variant_indexes, #string_literals),
158                            )*
159                            #fallback_ser_branch
160                        }
161                    }
162            }
163
164            #[automatically_derived]
165            impl<'de> cynic::serde::Deserialize<'de> for #ident {
166                fn deserialize<__D>(deserializer: __D) -> Result<Self, __D::Error>
167                where
168                    __D: cynic::serde::Deserializer<'de>,
169                {
170                    let desered_string = <String as cynic::serde::Deserialize>::deserialize(deserializer)?;
171                    match desered_string.as_ref() {
172                        #(
173                            #string_literals => Ok(#ident::#variants),
174                        )*
175                        #fallback_deser_branch
176                    }
177                }
178            }
179
180            cynic::impl_coercions!(#ident, #schema_module::#enum_marker_ident);
181
182            #[automatically_derived]
183            impl #schema_module::variable::Variable for #ident {
184                const TYPE: cynic::variables::VariableType = cynic::variables::VariableType::Named(#graphql_type_name);
185            }
186        })
187    } else {
188        Err(syn::Error::new(
189            enum_span,
190            "Enum can only be derived from an enum".to_string(),
191        )
192        .into())
193    }
194}
195
196fn join_variants<'a>(
197    variants: impl IntoIterator<Item = &'a EnumDeriveVariant>,
198    enum_def: &'a EnumType<'a>,
199    enum_name: &str,
200    rename_all: RenameAll,
201    exhaustive: bool,
202    enum_span: &Span,
203) -> Result<Vec<(&'a EnumDeriveVariant, &'a EnumValue<'a>)>, TokenStream> {
204    let mut has_fallback = false;
205    let mut map = BTreeMap::new();
206    for variant in variants {
207        if *variant.fallback {
208            has_fallback = true;
209            // We can't join up a fallback as it has no corresponding GQL value.
210            // We handle them separately.
211            continue;
212        }
213        let graphql_ident = variant.graphql_ident(rename_all);
214        map.insert(
215            graphql_ident.graphql_name(),
216            (Some(variant), enum_def.value(&graphql_ident)),
217        );
218    }
219
220    for value in &enum_def.values {
221        if !map.contains_key(value.name.as_str()) {
222            map.insert(value.name.as_str().to_owned(), (None, Some(value)));
223        }
224    }
225
226    let mut missing_variants = vec![];
227    let mut errors = TokenStream::new();
228    for (graphql_name, value) in map.iter() {
229        match value {
230            (None, Some(enum_value)) => missing_variants.push(enum_value.name.as_str()),
231            (Some(variant), None) => {
232                let candidates = map
233                    .values()
234                    .flat_map(|v| v.1.map(|input| input.name.as_str()));
235                let guess_field = guess_field(candidates, graphql_name);
236                errors.extend(
237                    syn::Error::new(
238                        variant.ident.span(),
239                        format!(
240                            "Could not find a variant {} in the GraphQL enum {}.{}",
241                            graphql_name,
242                            enum_name,
243                            format_guess(guess_field)
244                        ),
245                    )
246                    .to_compile_error(),
247                )
248            }
249            _ => (),
250        }
251    }
252    if !missing_variants.is_empty() && (exhaustive || !has_fallback) {
253        let missing_variants_string = missing_variants.join(", ");
254        errors.extend(
255            syn::Error::new(
256                *enum_span,
257                format!("Missing variants: {}", missing_variants_string),
258            )
259            .to_compile_error(),
260        )
261    }
262    if !errors.is_empty() {
263        return Err(errors);
264    }
265
266    Ok(map
267        .into_iter()
268        .filter_map(|(_, (a, b))| Some((a?, b.unwrap())))
269        .collect())
270}
271
#[cfg(test)]
mod tests {
    use {
        assert_matches::assert_matches, darling::util::SpannedValue, rstest::rstest,
        std::collections::HashSet, syn::parse_quote,
    };

    use {super::*, crate::schema::FieldName};

    /// Builds a unit variant named `name` with no rename and no fallback flag.
    fn unit_variant(name: &str) -> EnumDeriveVariant {
        EnumDeriveVariant {
            ident: proc_macro2::Ident::new(name, Span::call_site()),
            rename: None,
            fallback: Default::default(),
            fields: darling::ast::Style::Unit.into(),
        }
    }

    #[rstest(
        enum_variant_1,
        enum_variant_2,
        enum_value_1,
        enum_value_2,
        rename_rule,
        case(
            "Cheesecake",
            "IceCream",
            "CHEESECAKE",
            "ICE_CREAM",
            RenameAll::ScreamingSnakeCase
        ),
        case("CHEESECAKE", "ICE_CREAM", "CHEESECAKE", "ICE_CREAM", RenameAll::None)
    )]
    fn join_variants_happy_path(
        enum_variant_1: &str,
        enum_variant_2: &str,
        enum_value_1: &str,
        enum_value_2: &str,
        rename_rule: RenameAll,
    ) {
        let variants = vec![unit_variant(enum_variant_1), unit_variant(enum_variant_2)];
        let gql_enum = EnumType {
            name: "Desserts".into(),
            values: vec![
                EnumValue {
                    name: FieldName::new(enum_value_1),
                },
                EnumValue {
                    name: FieldName::new(enum_value_2),
                },
            ],
        };

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            rename_rule,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));
        let pairs = result.unwrap();

        assert_eq!(pairs.len(), 2);

        let names: HashSet<_> = pairs
            .iter()
            .map(|(variant, ty)| (variant.ident.to_string(), ty.name.clone()))
            .collect();

        assert_eq!(
            names,
            maplit::hashset! {(enum_variant_1.into(), FieldName::new(enum_value_1)), (enum_variant_2.into(), FieldName::new(enum_value_2))}
        );
    }

    #[test]
    fn join_variants_with_field_rename() {
        let variants = vec![
            unit_variant("Cheesecake"),
            EnumDeriveVariant {
                rename: Some(SpannedValue::new("iced-goodness".into(), Span::call_site())),
                ..unit_variant("IceCream")
            },
        ];
        let gql_enum = EnumType {
            name: "Desserts".into(),
            values: vec![
                EnumValue {
                    name: FieldName::new("CHEESECAKE"),
                },
                EnumValue {
                    name: FieldName::new("iced-goodness"),
                },
            ],
        };

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            RenameAll::ScreamingSnakeCase,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));
        let pairs = result.unwrap();

        assert_eq!(pairs.len(), 2);

        let names: HashSet<_> = pairs
            .iter()
            .map(|(variant, ty)| (variant.ident.to_string(), ty.name.clone()))
            .collect();

        assert_eq!(
            names,
            maplit::hashset! {("Cheesecake".into(), FieldName::new("CHEESECAKE")), ("IceCream".into(), FieldName::new("iced-goodness"))}
        );
    }

    #[test]
    fn join_variants_missing_rust_variant() {
        let variants = vec![unit_variant("CHEESECAKE")];
        let gql_enum = EnumType {
            name: "Desserts".into(),
            values: vec![
                EnumValue {
                    name: FieldName::new("CHEESECAKE"),
                },
                EnumValue {
                    name: FieldName::new("ICE_CREAM"),
                },
            ],
        };

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            RenameAll::None,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Err(_));
    }

    #[test]
    fn join_variants_missing_rust_variant_in_a_non_exhaustive_enum() {
        let variants = vec![
            unit_variant("FIRST"),
            EnumDeriveVariant {
                fallback: SpannedValue::new(true, Span::call_site()),
                ..unit_variant("FALLBACK")
            },
        ];
        let gql_enum = EnumType {
            name: "Enum".into(),
            values: vec![
                EnumValue {
                    name: FieldName::new("FIRST"),
                },
                EnumValue {
                    name: FieldName::new("SECOND"),
                },
            ],
        };

        let result = join_variants(
            &variants,
            &gql_enum,
            "Enum",
            RenameAll::None,
            false,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));
    }

    #[test]
    fn join_variants_missing_gql_variant() {
        let variants = vec![unit_variant("CHEESECAKE")];
        let gql_enum = EnumType {
            name: "Desserts".into(),
            values: vec![EnumValue {
                name: FieldName::new("ICE_CREAM"),
            }],
        };

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            RenameAll::None,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Err(_));
    }

    #[rstest(input => [
        parse_quote!(
            #[cynic(
                schema_path = "../schemas/test_cases.graphql",
            )]
            enum States {
                Open,
                Closed,
                Deleted
            }
        ),
    ])]
    fn snapshot_enum_derive(input: syn::DeriveInput) {
        let tokens = enum_derive(&input).unwrap();

        insta::assert_snapshot!(format_code(tokens.to_string()));
    }

    /// Pipes `input` through `rustfmt` so snapshots are readable and stable.
    ///
    /// Panics if `rustfmt` cannot be spawned, cannot be written to, exits
    /// unsuccessfully, or produces non-UTF-8 output. A formatting failure
    /// therefore can't be silently recorded as an empty snapshot.
    fn format_code(input: String) -> String {
        use std::io::Write;

        let mut cmd = std::process::Command::new("rustfmt")
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::inherit())
            .spawn()
            .expect("failed to execute rustfmt");

        write!(
            cmd.stdin.as_mut().expect("rustfmt stdin is piped"),
            "{}",
            input
        )
        .expect("failed to write code to rustfmt");

        // `wait_with_output` closes stdin first, letting rustfmt see EOF.
        let output = cmd.wait_with_output().expect("failed to wait for rustfmt");
        assert!(
            output.status.success(),
            "rustfmt failed with {}",
            output.status
        );

        String::from_utf8(output.stdout).expect("rustfmt produced invalid UTF-8")
    }
}