1use super::directives::FieldDirective;
2
3use {darling::util::SpannedValue, proc_macro2::Span, std::collections::HashSet};
4
5use crate::{idents::RenamableFieldIdent, schema::SchemaInput, types::CheckMode, Errors};
6
/// The struct a fragment derive is applied to, as parsed by darling from the
/// item and its `#[cynic(...)]` attributes.
///
/// Only structs with named fields are accepted (`supports(struct_named)`).
#[derive(darling::FromDeriveInput)]
#[darling(attributes(cynic), supports(struct_named))]
pub struct FragmentDeriveInput {
    /// Name of the struct being derived on.
    pub(super) ident: proc_macro2::Ident,
    /// The struct's fields.
    pub(super) data: darling::ast::Data<(), RawFragmentDeriveField>,
    /// Generics declared on the struct.
    pub(super) generics: syn::Generics,

    /// The `schema` option: a schema name, resolved by `schema_input`.
    /// Mutually exclusive with `schema_path`.
    #[darling(default)]
    schema: Option<SpannedValue<String>>,
    /// The `schema_path` option: resolved by `schema_input` via
    /// `SchemaInput::from_schema_path`. Mutually exclusive with `schema`.
    #[darling(default)]
    schema_path: Option<SpannedValue<String>>,

    /// The `schema_module` option. The trailing underscore keeps the field
    /// name distinct from the `schema_module()` accessor, which falls back to
    /// `schema` when this is unset.
    #[darling(default, rename = "schema_module")]
    schema_module_: Option<syn::Path>,

    /// The `graphql_type` option. When absent, the struct's own name is used
    /// (see `graphql_type_name`).
    #[darling(default)]
    pub graphql_type: Option<SpannedValue<String>>,

    /// The `no_deserialize` flag. Not read in this file.
    /// NOTE(review): presumably suppresses generating a deserialize impl —
    /// confirm against the code generator.
    #[darling(default)]
    pub(super) no_deserialize: bool,

    /// The `variables` option, exposed through `variables()`.
    #[darling(default)]
    variables: Option<syn::Path>,
}
31
32impl FragmentDeriveInput {
33 pub fn schema_module(&self) -> syn::Path {
34 if let Some(schema_module) = &self.schema_module_ {
35 return schema_module.clone();
36 }
37 syn::parse2(quote::quote! { schema }).unwrap()
38 }
39
40 pub fn graphql_type_name(&self) -> String {
41 self.graphql_type
42 .as_ref()
43 .map(|sp| sp.to_string())
44 .unwrap_or_else(|| self.ident.to_string())
45 }
46
47 pub fn graphql_type_span(&self) -> Span {
48 self.graphql_type
49 .as_ref()
50 .map(|val| val.span())
51 .unwrap_or_else(|| self.ident.span())
52 }
53
54 pub fn validate(&self) -> Result<Vec<FragmentDeriveField>, Errors> {
55 let data_field_is_empty = matches!(self.data.clone(), darling::ast::Data::Struct(fields) if fields.fields.is_empty());
56 if data_field_is_empty {
57 return Err(syn::Error::new(
58 self.ident.span(),
59 format!(
60 "At least one field should be selected for `{}`.",
61 self.ident
62 ),
63 )
64 .into());
65 }
66
67 let mut fields = vec![];
68 let mut errors = Errors::default();
69
70 let results = self
71 .data
72 .clone()
73 .map_struct_fields(|field| field.validate())
74 .take_struct()
75 .unwrap()
76 .into_iter();
77
78 for result in results {
79 match result {
80 Ok(field) => fields.push(field),
81 Err(error) => errors.extend(error),
82 }
83 }
84
85 if !errors.is_empty() {
86 return Err(errors);
87 }
88
89 Ok(fields)
90 }
91
92 pub fn detect_aliases(&mut self) {
93 let mut names = HashSet::new();
94 if let darling::ast::Data::Struct(fields) = &mut self.data {
95 for field in &mut fields.fields {
96 if let Some(rename) = &mut field.rename {
97 let name = rename.as_str();
98 if names.contains(name) {
99 field.alias = true.into();
100 continue;
101 }
102 names.insert(name);
103 }
104 }
105 }
106 }
107
108 pub fn variables(&self) -> Option<syn::Path> {
109 self.variables.clone()
110 }
111
112 pub fn schema_input(&self) -> Result<SchemaInput, syn::Error> {
113 match (&self.schema, &self.schema_path) {
114 (None, None) => SchemaInput::default().map_err(|e| e.into_syn_error(Span::call_site())),
115 (None, Some(path)) => SchemaInput::from_schema_path(path.as_ref())
116 .map_err(|e| e.into_syn_error(path.span())),
117 (Some(name), None) => SchemaInput::from_schema_name(name.as_ref())
118 .map_err(|e| e.into_syn_error(name.span())),
119 (Some(_), Some(path)) => Err(syn::Error::new(
120 path.span(),
121 "Only one of schema_path & schema can be provided",
122 )),
123 }
124 }
125}
126
/// A single field of the derived struct, as parsed by darling from its
/// `#[cynic(...)]` attributes, before validation.
///
/// The field's `arguments` and `directives` attributes are forwarded untouched
/// into `attrs`. Directives are parsed later, in `validate`.
#[derive(darling::FromField, Clone)]
#[darling(attributes(cynic), forward_attrs(arguments, directives))]
pub struct RawFragmentDeriveField {
    /// The field's name. `None` only for unnamed fields, which this derive
    /// does not accept.
    pub(super) ident: Option<proc_macro2::Ident>,
    /// The field's Rust type.
    pub(super) ty: syn::Type,

    /// The forwarded `arguments` / `directives` attributes.
    pub(super) attrs: Vec<syn::Attribute>,

    /// The `flatten` flag. Cannot be combined with `recurse`, `spread`,
    /// `default`, or skip/include directives.
    #[darling(default)]
    pub(super) flatten: SpannedValue<bool>,

    /// The `recurse` option.
    /// NOTE(review): the `u8` is presumably a recursion depth limit — confirm.
    #[darling(default)]
    pub(super) recurse: Option<SpannedValue<u8>>,

    /// The `spread` flag. Cannot be combined with `flatten`, `recurse`,
    /// `default`, or skip/include directives.
    #[darling(default)]
    pub(super) spread: SpannedValue<bool>,

    /// The `rename` option: the GraphQL name used instead of the Rust field
    /// name (applied in `graphql_ident`).
    #[darling(default)]
    rename: Option<SpannedValue<String>>,

    /// The `alias` flag. Only valid together with `rename`. Also set by
    /// `FragmentDeriveInput::detect_aliases` when an earlier field uses the
    /// same rename.
    #[darling(default)]
    alias: SpannedValue<bool>,

    /// The `feature` option. Not read in this file.
    #[darling(default)]
    pub(super) feature: Option<SpannedValue<String>>,

    /// The `default` flag. Cannot be combined with `spread`, `recurse`, or
    /// `flatten`.
    #[darling(default)]
    pub(super) default: SpannedValue<bool>,
}
156
157pub struct FragmentDeriveField {
158 pub(super) raw_field: RawFragmentDeriveField,
159
160 pub(super) directives: Vec<super::directives::FieldDirective>,
161}
162
163impl RawFragmentDeriveField {
164 pub fn validate(self) -> Result<FragmentDeriveField, Errors> {
165 if *self.flatten && self.recurse.is_some() {
166 return Err(syn::Error::new(
167 self.recurse.as_ref().unwrap().span(),
168 "A field can't be recurse if it's being flattened",
169 )
170 .into());
171 }
172
173 if *self.flatten && *self.spread {
174 return Err(syn::Error::new(
175 self.flatten.span(),
176 "A field can't be flattened if it's also being spread",
177 )
178 .into());
179 }
180
181 if *self.spread && self.recurse.is_some() {
182 return Err(syn::Error::new(
183 self.recurse.as_ref().unwrap().span(),
184 "A field can't be recurse if it's being spread",
185 )
186 .into());
187 }
188
189 if *self.alias && self.rename.is_none() {
190 return Err(syn::Error::new(
191 self.alias.span(),
192 "You can only alias a renamed field. Try removing `alias` or adding a rename",
193 )
194 .into());
195 }
196
197 if *self.default && *self.spread {
198 return Err(syn::Error::new(
199 self.default.span(),
200 "A field can't be defaulted if it's also being spread",
201 )
202 .into());
203 }
204
205 if *self.default && self.recurse.is_some() {
206 return Err(syn::Error::new(
207 self.recurse.unwrap().span(),
208 "A field can't be recurse if it's also being defaulted",
209 )
210 .into());
211 }
212
213 if *self.default && *self.flatten {
214 return Err(syn::Error::new(
215 self.default.span(),
216 "A field can't be defaulted if it's being flattened",
217 )
218 .into());
219 }
220
221 let directives = super::directives::directives_from_field_attrs(&self.attrs)?;
222 let skippable = directives.iter().any(|directive| {
223 matches!(
224 directive,
225 FieldDirective::Include(_) | FieldDirective::Skip(_)
226 )
227 });
228
229 if skippable {
230 if *self.spread {
231 return Err(syn::Error::new(
232 self.spread.span(),
233 "spread can't currently be used on fields with skip or include directives",
234 )
235 .into());
236 } else if *self.flatten {
237 return Err(syn::Error::new(
238 self.flatten.span(),
239 "flatten can't currently be used on fields with skip or include directives",
240 )
241 .into());
242 } else if let Some(recurse) = self.recurse {
243 return Err(syn::Error::new(
244 recurse.span(),
245 "recurse can't currently be used on fields with skip or include directives",
246 )
247 .into());
248 }
249 }
250
251 Ok(FragmentDeriveField {
252 directives,
253 raw_field: self,
254 })
255 }
256}
257
258impl FragmentDeriveField {
259 pub(super) fn type_check_mode(&self) -> CheckMode {
260 if *self.raw_field.flatten {
261 CheckMode::Flattening
262 } else if self.raw_field.recurse.is_some() {
263 CheckMode::Recursing
264 } else if *self.raw_field.spread {
265 CheckMode::Spreading
266 } else if self.has_default() {
267 CheckMode::Defaulted
268 } else if self.is_skippable() {
269 CheckMode::Skippable
270 } else {
271 CheckMode::OutputTypes
272 }
273 }
274
275 pub(super) fn is_skippable(&self) -> bool {
276 self.directives.iter().any(|directive| {
277 matches!(
278 directive,
279 FieldDirective::Include(_) | FieldDirective::Skip(_)
280 )
281 })
282 }
283
284 pub(super) fn spread(&self) -> bool {
285 *self.raw_field.spread
286 }
287
288 pub(super) fn ident(&self) -> Option<&proc_macro2::Ident> {
289 self.raw_field.ident.as_ref()
290 }
291
292 pub(super) fn graphql_ident(&self) -> RenamableFieldIdent {
293 let mut ident = RenamableFieldIdent::from(
294 self.raw_field
295 .ident
296 .clone()
297 .expect("FragmentDerive only supports named structs"),
298 );
299 if let Some(rename) = &self.raw_field.rename {
300 let span = rename.span();
301 let rename = (**rename).clone();
302 ident.set_rename(rename, span)
303 }
304 ident
305 }
306
307 pub(super) fn alias(&self) -> Option<String> {
308 self.raw_field.alias.then(|| {
309 self.raw_field
310 .ident
311 .as_ref()
312 .expect("ident is required")
313 .to_string()
314 })
315 }
316
317 pub(super) fn has_default(&self) -> bool {
318 *self.raw_field.default
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    use quote::format_ident;

    /// Builds a field called `name` of type `String` with every
    /// `#[cynic(...)]` option unset. Tests override the options they need
    /// with struct-update syntax.
    fn raw_field(name: &str) -> RawFragmentDeriveField {
        RawFragmentDeriveField {
            ident: Some(format_ident!("{}", name)),
            ty: syn::parse_quote! { String },
            attrs: vec![],
            flatten: false.into(),
            recurse: None,
            spread: false.into(),
            rename: None,
            alias: false.into(),
            feature: None,
            default: false.into(),
        }
    }

    /// Builds a `TestInput` struct over `fields` that selects from the `abcd`
    /// GraphQL type.
    fn input_with_fields(fields: Vec<RawFragmentDeriveField>) -> FragmentDeriveInput {
        FragmentDeriveInput {
            ident: format_ident!("TestInput"),
            data: darling::ast::Data::Struct(darling::ast::Fields::new(
                darling::ast::Style::Struct,
                fields,
            )),
            generics: Default::default(),
            schema: None,
            schema_path: Some("abcd".to_string().into()),
            schema_module_: Some(syn::parse2(quote::quote! { abcd }).unwrap()),
            graphql_type: Some("abcd".to_string().into()),
            variables: None,
            no_deserialize: false,
        }
    }

    #[test]
    fn test_fragment_derive_validate_pass() {
        let input = FragmentDeriveInput {
            schema_module_: None,
            ..input_with_fields(vec![
                raw_field("field_one"),
                RawFragmentDeriveField {
                    flatten: true.into(),
                    ..raw_field("field_two")
                },
                RawFragmentDeriveField {
                    recurse: Some(8.into()),
                    rename: Some("fieldThree".to_string().into()),
                    ..raw_field("field_three")
                },
                RawFragmentDeriveField {
                    spread: true.into(),
                    rename: Some("fieldThree".to_string().into()),
                    alias: true.into(),
                    ..raw_field("some_spread")
                },
            ])
        };

        assert!(input.validate().is_ok());
        assert_eq!(input.graphql_type_name(), "abcd");
    }

    #[test]
    fn test_fragment_derive_validate_fails() {
        let input = input_with_fields(vec![
            raw_field("field_one"),
            // flatten + recurse
            RawFragmentDeriveField {
                flatten: true.into(),
                recurse: Some(8.into()),
                ..raw_field("field_two")
            },
            // flatten + recurse
            RawFragmentDeriveField {
                flatten: true.into(),
                recurse: Some(8.into()),
                ..raw_field("field_three")
            },
            // flatten + spread
            RawFragmentDeriveField {
                flatten: true.into(),
                spread: true.into(),
                ..raw_field("some_spread")
            },
            // spread + recurse
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                spread: true.into(),
                ..raw_field("some_other_spread")
            },
            // spread + recurse, and alias without a rename
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                spread: true.into(),
                alias: true.into(),
                ..raw_field("some_other_spread")
            },
        ]);

        // Every field except `field_one` is invalid, and each one reports only
        // its first problem.
        let errors = input.validate().map(|_| ()).unwrap_err();
        assert_eq!(errors.len(), 5);
    }

    #[test]
    fn test_fragment_derive_validate_empty_struct_fails() {
        let input = input_with_fields(vec![]);
        let errors = input.validate().map(|_| ()).unwrap_err();
        insta::assert_snapshot!(errors.to_compile_errors().to_string(), @r###":: core :: compile_error ! { "At least one field should be selected for `TestInput`." }"###);
    }

    #[test]
    fn test_fragment_derive_validate_pass_no_graphql_type() {
        let input = FragmentDeriveInput {
            graphql_type: None,
            ..input_with_fields(vec![
                raw_field("field_one"),
                RawFragmentDeriveField {
                    flatten: true.into(),
                    ..raw_field("field_two")
                },
                RawFragmentDeriveField {
                    recurse: Some(8.into()),
                    ..raw_field("field_three")
                },
            ])
        };

        assert!(input.validate().is_ok());
        // Without `graphql_type`, the struct's own name is used.
        assert_eq!(input.graphql_type_name(), "TestInput");
    }

    #[test]
    fn test_detect_aliases_marks_repeated_renames() {
        let renamed = |name: &str, rename: &str| RawFragmentDeriveField {
            rename: Some(rename.to_string().into()),
            ..raw_field(name)
        };
        let mut input = input_with_fields(vec![
            renamed("first", "fieldOne"),
            renamed("second", "fieldOne"),
            renamed("third", "fieldTwo"),
            raw_field("fourth"),
        ]);

        input.detect_aliases();

        // Only the second use of a rename is marked, and un-renamed fields
        // are never marked.
        let fields = input.data.take_struct().expect("input is a struct");
        let aliases: Vec<bool> = fields.fields.iter().map(|field| *field.alias).collect();
        assert_eq!(aliases, vec![false, true, false, false]);
    }
}