1use {
2 proc_macro2::{Span, TokenStream},
3 std::collections::BTreeMap,
4};
5
6use crate::{
7 error::Errors,
8 idents::RenameAll,
9 schema::{
10 types::{EnumType, EnumValue},
11 Schema, Unvalidated,
12 },
13};
14
15pub(crate) mod input;
16
17pub use input::EnumDeriveInput;
18use {
19 crate::suggestions::{format_guess, guess_field},
20 input::EnumDeriveVariant,
21};
22
23pub fn enum_derive(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
24 use {darling::FromDeriveInput, syn::spanned::Spanned};
25
26 let enum_span = ast.span();
27
28 match EnumDeriveInput::from_derive_input(ast) {
29 Ok(input) => {
30 let schema = Schema::new(input.schema_input()?);
31
32 enum_derive_impl(input, &schema, enum_span).or_else(|e| Ok(e.to_compile_errors()))
33 }
34 Err(e) => Ok(e.write_errors()),
35 }
36}
37
38pub fn enum_derive_impl(
39 input: EnumDeriveInput,
40 schema: &Schema<'_, Unvalidated>,
41 enum_span: Span,
42) -> Result<TokenStream, Errors> {
43 use quote::quote;
44
45 let enum_def = schema
46 .lookup::<EnumType<'_>>(&input.graphql_type_name())
47 .map_err(|e| syn::Error::new(input.graphql_type_span(), e))?;
48
49 let rename_all = input.rename_all.unwrap_or(RenameAll::ScreamingSnakeCase);
50
51 input.validate()?;
52
53 if let darling::ast::Data::Enum(variants) = &input.data {
54 let fallback = variants.iter().find(|variant| *variant.fallback);
55
56 if input.non_exhaustive && fallback.is_none() {
57 return Err(syn::Error::new(
58 enum_span,
59 "Enum marked as non-exhaustive must have a fallback variant".to_string(),
60 )
61 .into());
62 }
63
64 let pairs = match join_variants(
65 variants.iter().map(|variant| variant.as_ref()),
66 &enum_def,
67 &input.ident.to_string(),
68 rename_all,
69 !input.non_exhaustive,
70 &enum_span,
71 ) {
72 Ok(pairs) => pairs,
73 Err(error_tokens) => return Ok(error_tokens),
74 };
75
76 let graphql_type_name = proc_macro2::Literal::string(&input.graphql_type_name());
77 let enum_marker_ident = enum_def.marker_ident().to_rust_ident();
78
79 let string_literals: Vec<_> = pairs
80 .iter()
81 .map(|(_, value)| value.name.to_literal())
82 .collect();
83
84 let variants: Vec<_> = pairs.iter().map(|(variant, _)| &variant.ident).collect();
85 let variant_indexes: Vec<_> = pairs
86 .iter()
87 .enumerate()
88 .map(|(i, _)| {
89 proc_macro2::Literal::u32_suffixed(
90 i.try_into().expect("an enum with less than 2^32 variants"),
91 )
92 })
93 .collect();
94
95 let schema_module = input.schema_module();
96 let ident = input.ident;
97
98 let fallback_ser_branch = match fallback {
99 None => quote! {},
100 Some(fallback) if fallback.fields.fields.is_empty() => {
101 let fallback_ident = &fallback.ident;
102 quote! {
103 #ident::#fallback_ident => {
104 use cynic::serde::ser::Error;
105 Err(__S::Error::custom("cynic can't serialize the fallback variant of an enum unless it has a field"))
106 }
107 }
108 }
109 Some(fallback) => {
110 let fallback_ident = &fallback.ident;
111 quote! {
112 #ident::#fallback_ident(value) => {
113 serializer.serialize_str(value)
114 }
115 }
116 }
117 };
118
119 let fallback_deser_branch = match fallback {
120 None => quote! {
121 unknown => {
122 const VARIANTS: &'static [&'static str] = &[#(#string_literals),*];
123 Err(cynic::serde::de::Error::unknown_variant(unknown, VARIANTS))
124 }
125 },
126 Some(fallback) if fallback.fields.fields.is_empty() => {
127 let fallback_ident = &fallback.ident;
128 quote! {
129 _ => {
130 Ok(#ident::#fallback_ident)
131 }
132 }
133 }
134 Some(fallback) => {
135 let fallback_ident = &fallback.ident;
136 quote! {
137 _ => {
138 Ok(#ident::#fallback_ident(desered_string))
139 }
140 }
141 }
142 };
143
144 Ok(quote! {
145 #[automatically_derived]
146 impl cynic::Enum for #ident {
147 type SchemaType = #schema_module::#enum_marker_ident;
148 }
149
150 #[automatically_derived]
151 impl cynic::serde::Serialize for #ident {
152 fn serialize<__S>(&self, serializer: __S) -> Result<__S::Ok, __S::Error>
153 where
154 __S: cynic::serde::Serializer {
155 match self {
156 #(
157 #ident::#variants => serializer.serialize_unit_variant(#graphql_type_name, #variant_indexes, #string_literals),
158 )*
159 #fallback_ser_branch
160 }
161 }
162 }
163
164 #[automatically_derived]
165 impl<'de> cynic::serde::Deserialize<'de> for #ident {
166 fn deserialize<__D>(deserializer: __D) -> Result<Self, __D::Error>
167 where
168 __D: cynic::serde::Deserializer<'de>,
169 {
170 let desered_string = <String as cynic::serde::Deserialize>::deserialize(deserializer)?;
171 match desered_string.as_ref() {
172 #(
173 #string_literals => Ok(#ident::#variants),
174 )*
175 #fallback_deser_branch
176 }
177 }
178 }
179
180 cynic::impl_coercions!(#ident, #schema_module::#enum_marker_ident);
181
182 #[automatically_derived]
183 impl #schema_module::variable::Variable for #ident {
184 const TYPE: cynic::variables::VariableType = cynic::variables::VariableType::Named(#graphql_type_name);
185 }
186 })
187 } else {
188 Err(syn::Error::new(
189 enum_span,
190 "Enum can only be derived from an enum".to_string(),
191 )
192 .into())
193 }
194}
195
196fn join_variants<'a>(
197 variants: impl IntoIterator<Item = &'a EnumDeriveVariant>,
198 enum_def: &'a EnumType<'a>,
199 enum_name: &str,
200 rename_all: RenameAll,
201 exhaustive: bool,
202 enum_span: &Span,
203) -> Result<Vec<(&'a EnumDeriveVariant, &'a EnumValue<'a>)>, TokenStream> {
204 let mut has_fallback = false;
205 let mut map = BTreeMap::new();
206 for variant in variants {
207 if *variant.fallback {
208 has_fallback = true;
209 continue;
212 }
213 let graphql_ident = variant.graphql_ident(rename_all);
214 map.insert(
215 graphql_ident.graphql_name(),
216 (Some(variant), enum_def.value(&graphql_ident)),
217 );
218 }
219
220 for value in &enum_def.values {
221 if !map.contains_key(value.name.as_str()) {
222 map.insert(value.name.as_str().to_owned(), (None, Some(value)));
223 }
224 }
225
226 let mut missing_variants = vec![];
227 let mut errors = TokenStream::new();
228 for (graphql_name, value) in map.iter() {
229 match value {
230 (None, Some(enum_value)) => missing_variants.push(enum_value.name.as_str()),
231 (Some(variant), None) => {
232 let candidates = map
233 .values()
234 .flat_map(|v| v.1.map(|input| input.name.as_str()));
235 let guess_field = guess_field(candidates, graphql_name);
236 errors.extend(
237 syn::Error::new(
238 variant.ident.span(),
239 format!(
240 "Could not find a variant {} in the GraphQL enum {}.{}",
241 graphql_name,
242 enum_name,
243 format_guess(guess_field)
244 ),
245 )
246 .to_compile_error(),
247 )
248 }
249 _ => (),
250 }
251 }
252 if !missing_variants.is_empty() && (exhaustive || !has_fallback) {
253 let missing_variants_string = missing_variants.join(", ");
254 errors.extend(
255 syn::Error::new(
256 *enum_span,
257 format!("Missing variants: {}", missing_variants_string),
258 )
259 .to_compile_error(),
260 )
261 }
262 if !errors.is_empty() {
263 return Err(errors);
264 }
265
266 Ok(map
267 .into_iter()
268 .filter_map(|(_, (a, b))| Some((a?, b.unwrap())))
269 .collect())
270}
271
#[cfg(test)]
mod tests {
    use {
        assert_matches::assert_matches, darling::util::SpannedValue, rstest::rstest,
        std::collections::HashSet, syn::parse_quote,
    };

    use {super::*, crate::schema::FieldName};

    /// Builds a unit variant with no rename and no fallback flag.
    fn unit_variant(name: &str) -> EnumDeriveVariant {
        EnumDeriveVariant {
            ident: proc_macro2::Ident::new(name, Span::call_site()),
            rename: None,
            fallback: Default::default(),
            fields: darling::ast::Style::Unit.into(),
        }
    }

    /// Builds a GraphQL enum definition whose values are `values`, in order.
    fn enum_type<'a>(name: &'static str, values: &[&'a str]) -> EnumType<'a> {
        EnumType {
            name: name.into(),
            values: values
                .iter()
                .copied()
                .map(|value| EnumValue {
                    name: FieldName::new(value),
                })
                .collect(),
        }
    }

    #[rstest(
        enum_variant_1,
        enum_variant_2,
        enum_value_1,
        enum_value_2,
        rename_rule,
        case(
            "Cheesecake",
            "IceCream",
            "CHEESECAKE",
            "ICE_CREAM",
            RenameAll::ScreamingSnakeCase
        ),
        case("CHEESECAKE", "ICE_CREAM", "CHEESECAKE", "ICE_CREAM", RenameAll::None)
    )]
    fn join_variants_happy_path(
        enum_variant_1: &str,
        enum_variant_2: &str,
        enum_value_1: &str,
        enum_value_2: &str,
        rename_rule: RenameAll,
    ) {
        let variants = vec![unit_variant(enum_variant_1), unit_variant(enum_variant_2)];
        let gql_enum = enum_type("Desserts", &[enum_value_1, enum_value_2]);

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            rename_rule,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));
        let pairs = result.unwrap();

        assert_eq!(pairs.len(), 2);

        let names: HashSet<_> = pairs
            .iter()
            .map(|(variant, ty)| (variant.ident.to_string(), ty.name.clone()))
            .collect();

        assert_eq!(
            names,
            maplit::hashset! {(enum_variant_1.into(), FieldName::new(enum_value_1)), (enum_variant_2.into(), FieldName::new(enum_value_2))}
        );
    }

    #[test]
    fn join_variants_with_field_rename() {
        let variants = vec![
            unit_variant("Cheesecake"),
            EnumDeriveVariant {
                rename: Some(SpannedValue::new("iced-goodness".into(), Span::call_site())),
                ..unit_variant("IceCream")
            },
        ];
        let gql_enum = enum_type("Desserts", &["CHEESECAKE", "iced-goodness"]);

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            RenameAll::ScreamingSnakeCase,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));
        let pairs = result.unwrap();

        assert_eq!(pairs.len(), 2);

        let names: HashSet<_> = pairs
            .iter()
            .map(|(variant, ty)| (variant.ident.to_string(), ty.name.clone()))
            .collect();

        assert_eq!(
            names,
            maplit::hashset! {("Cheesecake".into(), FieldName::new("CHEESECAKE")), ("IceCream".into(), FieldName::new("iced-goodness"))}
        );
    }

    #[test]
    fn join_variants_missing_rust_variant() {
        let variants = vec![unit_variant("CHEESECAKE")];
        let gql_enum = enum_type("Desserts", &["CHEESECAKE", "ICE_CREAM"]);

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            RenameAll::None,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Err(_));
    }

    #[test]
    fn join_variants_missing_rust_variant_in_a_non_exhaustive_enum() {
        let variants = vec![
            unit_variant("FIRST"),
            EnumDeriveVariant {
                fallback: SpannedValue::new(true, Span::call_site()),
                ..unit_variant("FALLBACK")
            },
        ];
        let gql_enum = enum_type("Enum", &["FIRST", "SECOND"]);

        let result = join_variants(
            &variants,
            &gql_enum,
            "Enum",
            RenameAll::None,
            false,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));
    }

    #[test]
    fn join_variants_missing_gql_variant() {
        let variants = vec![unit_variant("CHEESECAKE")];
        let gql_enum = enum_type("Desserts", &["ICE_CREAM"]);

        let result = join_variants(
            &variants,
            &gql_enum,
            "Desserts",
            RenameAll::None,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Err(_));
    }

    #[rstest(input => [
        parse_quote!(
            #[cynic(
                schema_path = "../schemas/test_cases.graphql",
            )]
            enum States {
                Open,
                Closed,
                Deleted
            }
        ),
    ])]
    fn snapshot_enum_derive(input: syn::DeriveInput) {
        let tokens = enum_derive(&input).unwrap();

        insta::assert_snapshot!(format_code(format!("{}", tokens)));
    }

    /// Pretty-prints generated code with `rustfmt` so snapshots are readable.
    ///
    /// Panics if `rustfmt` cannot be spawned or exits unsuccessfully. Without
    /// the status check, a broken expansion would be snapshotted as an empty
    /// string instead of failing the test.
    fn format_code(input: String) -> String {
        use std::io::Write;

        let mut child = std::process::Command::new("rustfmt")
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::inherit())
            .spawn()
            .expect("failed to execute rustfmt");

        child
            .stdin
            .as_mut()
            .expect("rustfmt stdin is piped")
            .write_all(input.as_bytes())
            .expect("failed to write to rustfmt stdin");

        // `wait_with_output` closes stdin before waiting, so rustfmt sees EOF.
        let output = child
            .wait_with_output()
            .expect("failed to wait for rustfmt");
        assert!(
            output.status.success(),
            "rustfmt failed with {}",
            output.status
        );

        String::from_utf8(output.stdout).expect("rustfmt produced invalid UTF-8")
    }
}