Skip to main content

cynic_codegen/fragment_derive/
input.rs

1use super::directives::FieldDirective;
2
3use {
4    darling::util::SpannedValue, proc_macro2::Span, std::collections::HashSet,
5    syn::spanned::Spanned,
6};
7
8use crate::{
9    Errors,
10    idents::{RenamableFieldIdent, RenameAll},
11    schema::SchemaInput,
12    types::CheckMode,
13};
14
/// Input of the fragment derive, populated by `darling` from a named struct
/// and its `#[cynic(...)]` attributes.
#[derive(darling::FromDeriveInput)]
#[darling(attributes(cynic), supports(struct_named))]
pub struct FragmentDeriveInput {
    /// Identifier of the struct the derive is applied to.  Also used as the
    /// GraphQL type name when `graphql_type` is not given.
    pub(super) ident: proc_macro2::Ident,
    /// The fields of the struct, still unvalidated.
    pub(super) data: darling::ast::Data<(), RawFragmentDeriveField>,
    /// Generics declared on the struct.
    pub(super) generics: syn::Generics,

    /// Schema name, handed to `SchemaInput::from_schema_name`.  Mutually
    /// exclusive with `schema_path`.
    #[darling(default)]
    schema: Option<SpannedValue<String>>,
    /// Schema path, handed to `SchemaInput::from_schema_path`.  Mutually
    /// exclusive with `schema`.
    #[darling(default)]
    schema_path: Option<SpannedValue<String>>,

    /// Value of the `schema_module` attribute.  The trailing underscore keeps
    /// the field name from clashing with the `schema_module()` accessor.
    #[darling(default, rename = "schema_module")]
    schema_module_: Option<syn::Path>,

    /// Explicit GraphQL type name; when absent the struct identifier is used.
    #[darling(default)]
    pub graphql_type: Option<SpannedValue<String>>,

    /// NOTE(review): not read in this file - presumably tells the code
    /// generator not to emit a deserialize impl; confirm against its users.
    #[darling(default)]
    pub(super) no_deserialize: bool,

    /// NOTE(review): not read in this file - presumably the rename rule that
    /// callers pass to `FragmentDeriveField::graphql_ident`; confirm.
    #[darling(default)]
    pub(super) rename_all: Option<RenameAll>,

    /// Path given in the `variables` attribute, exposed through `variables()`.
    #[darling(default)]
    variables: Option<syn::Path>,
}
42
43impl FragmentDeriveInput {
44    pub fn schema_module(&self) -> syn::Path {
45        if let Some(schema_module) = &self.schema_module_ {
46            return schema_module.clone();
47        }
48        syn::parse2(quote::quote! { schema }).unwrap()
49    }
50
51    pub fn graphql_type_name(&self) -> String {
52        self.graphql_type
53            .as_ref()
54            .map(|sp| sp.to_string())
55            .unwrap_or_else(|| self.ident.to_string())
56    }
57
58    pub fn graphql_type_span(&self) -> Span {
59        self.graphql_type
60            .as_ref()
61            .map(|val| val.span())
62            .unwrap_or_else(|| self.ident.span())
63    }
64
65    pub fn validate(&self) -> Result<Vec<FragmentDeriveField>, Errors> {
66        let data_field_is_empty = matches!(self.data.clone(), darling::ast::Data::Struct(fields) if fields.fields.is_empty());
67        if data_field_is_empty {
68            return Err(syn::Error::new(
69                self.ident.span(),
70                format!(
71                    "At least one field should be selected for `{}`.",
72                    self.ident
73                ),
74            )
75            .into());
76        }
77
78        let mut fields = vec![];
79        let mut errors = Errors::default();
80
81        let results = self
82            .data
83            .clone()
84            .map_struct_fields(|field| field.validate())
85            .take_struct()
86            .unwrap()
87            .into_iter();
88
89        for result in results {
90            match result {
91                Ok(field) => fields.push(field),
92                Err(error) => errors.extend(error),
93            }
94        }
95
96        if !errors.is_empty() {
97            return Err(errors);
98        }
99
100        Ok(fields)
101    }
102
103    pub fn detect_aliases(&mut self) {
104        let mut names = HashSet::new();
105        if let darling::ast::Data::Struct(fields) = &mut self.data {
106            for field in &mut fields.fields {
107                if let Some(rename) = &mut field.rename {
108                    let name = rename.as_str();
109                    if names.contains(name) {
110                        field.alias = true.into();
111                        continue;
112                    }
113                    names.insert(name);
114                }
115            }
116        }
117    }
118
119    pub fn variables(&self) -> Option<syn::Path> {
120        self.variables.clone()
121    }
122
123    pub fn schema_input(&self) -> Result<SchemaInput, syn::Error> {
124        match (&self.schema, &self.schema_path) {
125            (None, None) => SchemaInput::default().map_err(|e| e.into_syn_error(Span::call_site())),
126            (None, Some(path)) => SchemaInput::from_schema_path(path.as_ref())
127                .map_err(|e| e.into_syn_error(path.span())),
128            (Some(name), None) => SchemaInput::from_schema_name(name.as_ref())
129                .map_err(|e| e.into_syn_error(name.span())),
130            (Some(_), Some(path)) => Err(syn::Error::new(
131                path.span(),
132                "Only one of schema_path & schema can be provided",
133            )),
134        }
135    }
136}
137
/// A single struct field as parsed by `darling`, before the option
/// combinations have been checked by `validate`.
#[derive(darling::FromField, Clone)]
#[darling(attributes(cynic), forward_attrs(arguments, directives))]
pub struct RawFragmentDeriveField {
    /// Name of the field.  An `Option` because of `darling`; the accessors in
    /// this file expect it to be present.
    pub(super) ident: Option<proc_macro2::Ident>,
    /// Rust type of the field.
    pub(super) ty: syn::Type,

    /// The forwarded `arguments` and `directives` attributes; `validate`
    /// parses directives out of these.
    pub(super) attrs: Vec<syn::Attribute>,

    /// Selects `CheckMode::Flattening`.  Rejected together with `recurse`,
    /// `spread` or `default`.
    #[darling(default)]
    pub(super) flatten: SpannedValue<bool>,

    /// Selects `CheckMode::Recursing` when set.  Rejected together with
    /// `flatten`, `spread` or `default`.
    /// NOTE(review): the number is presumably a recursion depth limit - confirm.
    #[darling(default)]
    pub(super) recurse: Option<SpannedValue<u8>>,

    /// Selects `CheckMode::Spreading`.  Rejected together with `flatten`,
    /// `recurse` or `default`.
    #[darling(default)]
    pub(super) spread: SpannedValue<bool>,

    /// Explicit GraphQL name for the field; when absent the container's
    /// rename rule is applied to the Rust name instead.
    #[darling(default)]
    rename: Option<SpannedValue<String>>,

    /// Whether the field is aliased to its Rust name.  Only valid on a field
    /// that also has a `rename`.
    #[darling(default)]
    alias: SpannedValue<bool>,

    /// NOTE(review): not read in this file - presumably a feature name used
    /// to gate the field; confirm against the code generator.
    #[darling(default)]
    pub(super) feature: Option<SpannedValue<String>>,

    /// Selects `CheckMode::Defaulted`.  Rejected together with `spread`,
    /// `recurse` or `flatten`.
    #[darling(default)]
    pub(super) default: SpannedValue<bool>,
}
167
168pub struct FragmentDeriveField {
169    pub(super) raw_field: RawFragmentDeriveField,
170
171    pub(super) directives: Vec<super::directives::FieldDirective>,
172}
173
174impl RawFragmentDeriveField {
175    pub fn validate(self) -> Result<FragmentDeriveField, Errors> {
176        if *self.flatten && self.recurse.is_some() {
177            return Err(syn::Error::new(
178                self.recurse.as_ref().unwrap().span(),
179                "A field can't be recurse if it's being flattened",
180            )
181            .into());
182        }
183
184        if *self.flatten && *self.spread {
185            return Err(syn::Error::new(
186                self.flatten.span(),
187                "A field can't be flattened if it's also being spread",
188            )
189            .into());
190        }
191
192        if *self.spread && self.recurse.is_some() {
193            return Err(syn::Error::new(
194                self.recurse.as_ref().unwrap().span(),
195                "A field can't be recurse if it's being spread",
196            )
197            .into());
198        }
199
200        if *self.alias && self.rename.is_none() {
201            return Err(syn::Error::new(
202                self.alias.span(),
203                "You can only alias a renamed field.  Try removing `alias` or adding a rename",
204            )
205            .into());
206        }
207
208        if *self.default && *self.spread {
209            return Err(syn::Error::new(
210                self.default.span(),
211                "A field can't be defaulted if it's also being spread",
212            )
213            .into());
214        }
215
216        if *self.default && self.recurse.is_some() {
217            return Err(syn::Error::new(
218                self.recurse.unwrap().span(),
219                "A field can't be recurse if it's also being defaulted",
220            )
221            .into());
222        }
223
224        if *self.default && *self.flatten {
225            return Err(syn::Error::new(
226                self.default.span(),
227                "A field can't be defaulted if it's being flattened",
228            )
229            .into());
230        }
231
232        let directives = super::directives::directives_from_field_attrs(&self.attrs)?;
233        let skippable = directives.iter().any(|directive| {
234            matches!(
235                directive,
236                FieldDirective::Include(_) | FieldDirective::Skip(_)
237            )
238        });
239
240        if skippable {
241            if *self.spread {
242                return Err(syn::Error::new(
243                    self.spread.span(),
244                    "spread can't currently be used on fields with skip or include directives",
245                )
246                .into());
247            } else if *self.flatten {
248                return Err(syn::Error::new(
249                    self.flatten.span(),
250                    "flatten can't currently be used on fields with skip or include directives",
251                )
252                .into());
253            } else if let Some(recurse) = self.recurse {
254                return Err(syn::Error::new(
255                    recurse.span(),
256                    "recurse can't currently be used on fields with skip or include directives",
257                )
258                .into());
259            }
260        }
261
262        Ok(FragmentDeriveField {
263            directives,
264            raw_field: self,
265        })
266    }
267}
268
269impl FragmentDeriveField {
270    pub(super) fn type_check_mode(&self) -> CheckMode {
271        if *self.raw_field.flatten {
272            CheckMode::Flattening
273        } else if self.raw_field.recurse.is_some() {
274            CheckMode::Recursing
275        } else if *self.raw_field.spread {
276            CheckMode::Spreading
277        } else if self.has_default() {
278            CheckMode::Defaulted
279        } else if self.is_skippable() {
280            CheckMode::Skippable
281        } else {
282            CheckMode::OutputTypes
283        }
284    }
285
286    pub(super) fn is_skippable(&self) -> bool {
287        self.directives.iter().any(|directive| {
288            matches!(
289                directive,
290                FieldDirective::Include(_) | FieldDirective::Skip(_)
291            )
292        })
293    }
294
295    pub(super) fn spread(&self) -> bool {
296        *self.raw_field.spread
297    }
298
299    pub(super) fn ident(&self) -> Option<&proc_macro2::Ident> {
300        self.raw_field.ident.as_ref()
301    }
302
303    pub(super) fn graphql_ident(&self, rename_rule: RenameAll) -> RenamableFieldIdent {
304        let mut ident = RenamableFieldIdent::from(
305            self.raw_field
306                .ident
307                .clone()
308                .expect("FragmentDerive only supports named structs"),
309        );
310        if let Some(rename) = &self.raw_field.rename {
311            let span = rename.span();
312            let rename = (**rename).clone();
313            ident.set_rename(rename, span)
314        } else {
315            ident.rename_with(rename_rule, self.raw_field.ident.span());
316        }
317        ident
318    }
319
320    pub(super) fn alias(&self) -> Option<String> {
321        self.raw_field.alias.then(|| {
322            self.raw_field
323                .ident
324                .as_ref()
325                .expect("ident is required")
326                .to_string()
327        })
328    }
329
330    pub(super) fn has_default(&self) -> bool {
331        *self.raw_field.default
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    use quote::format_ident;

    /// Builds a `String`-typed field called `name` with every option unset.
    fn plain_field(name: &str) -> RawFragmentDeriveField {
        RawFragmentDeriveField {
            ident: Some(format_ident!("{}", name)),
            ty: syn::parse_quote! { String },
            attrs: vec![],
            flatten: false.into(),
            recurse: None,
            spread: false.into(),
            rename: None,
            alias: false.into(),
            feature: None,
            default: false.into(),
        }
    }

    /// The `abcd` schema module path used by most of the tests.
    fn abcd_module() -> Option<syn::Path> {
        Some(syn::parse2(quote::quote! { abcd }).unwrap())
    }

    /// Builds a derive input named `TestInput` with schema path `abcd` and
    /// the given fields, schema module and GraphQL type.
    fn derive_input(
        fields: Vec<RawFragmentDeriveField>,
        schema_module: Option<syn::Path>,
        graphql_type: Option<&str>,
    ) -> FragmentDeriveInput {
        FragmentDeriveInput {
            ident: format_ident!("TestInput"),
            data: darling::ast::Data::Struct(darling::ast::Fields::new(
                darling::ast::Style::Struct,
                fields,
            )),
            generics: Default::default(),
            schema: None,
            schema_path: Some("abcd".to_string().into()),
            schema_module_: schema_module,
            graphql_type: graphql_type.map(|name| name.to_string().into()),
            variables: None,
            no_deserialize: false,
            rename_all: None,
        }
    }

    #[test]
    fn test_fragment_derive_validate_pass() {
        let fields = vec![
            plain_field("field_one"),
            RawFragmentDeriveField {
                flatten: true.into(),
                ..plain_field("field_two")
            },
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                rename: Some("fieldThree".to_string().into()),
                ..plain_field("field_three")
            },
            RawFragmentDeriveField {
                spread: true.into(),
                rename: Some("fieldThree".to_string().into()),
                alias: true.into(),
                ..plain_field("some_spread")
            },
        ];
        let input = derive_input(fields, None, Some("abcd"));

        assert!(input.validate().is_ok());
    }

    #[test]
    fn test_fragment_derive_validate_fails() {
        let fields = vec![
            // The only valid field; each of the others breaks exactly one rule first.
            plain_field("field_one"),
            RawFragmentDeriveField {
                flatten: true.into(),
                recurse: Some(8.into()),
                ..plain_field("field_two")
            },
            RawFragmentDeriveField {
                flatten: true.into(),
                recurse: Some(8.into()),
                ..plain_field("field_three")
            },
            RawFragmentDeriveField {
                flatten: true.into(),
                spread: true.into(),
                ..plain_field("some_spread")
            },
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                spread: true.into(),
                ..plain_field("some_other_spread")
            },
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                spread: true.into(),
                alias: true.into(),
                ..plain_field("some_other_spread")
            },
        ];
        let input = derive_input(fields, abcd_module(), Some("abcd"));

        let errors = input.validate().map(|_| ()).unwrap_err();
        assert_eq!(errors.len(), 5);
    }

    #[test]
    fn test_fragment_derive_validate_failed() {
        let input = derive_input(vec![], abcd_module(), Some("abcd"));

        let errors = input.validate().map(|_| ()).unwrap_err();
        insta::assert_snapshot!(errors.to_compile_errors().to_string(), @r###":: core :: compile_error ! { "At least one field should be selected for `TestInput`." }"###);
    }

    #[test]
    fn test_fragment_derive_validate_pass_no_graphql_type() {
        let fields = vec![
            plain_field("field_one"),
            RawFragmentDeriveField {
                flatten: true.into(),
                ..plain_field("field_two")
            },
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                ..plain_field("field_three")
            },
        ];
        let input = derive_input(fields, abcd_module(), None);

        assert!(input.validate().is_ok())
    }
}