cynic_codegen/fragment_derive/input.rs

1use super::directives::FieldDirective;
2
3use {darling::util::SpannedValue, proc_macro2::Span, std::collections::HashSet};
4
5use crate::{idents::RenamableFieldIdent, schema::SchemaInput, types::CheckMode, Errors};
6
/// Options parsed by darling from a struct using the fragment derive.
///
/// Filled in from the struct itself and its `#[cynic(...)]` attributes; only
/// structs with named fields are accepted (`supports(struct_named)`).
#[derive(darling::FromDeriveInput)]
#[darling(attributes(cynic), supports(struct_named))]
pub struct FragmentDeriveInput {
    /// The name of the struct being derived on.
    pub(super) ident: proc_macro2::Ident,
    /// The struct's fields, each parsed as a [`RawFragmentDeriveField`].
    pub(super) data: darling::ast::Data<(), RawFragmentDeriveField>,
    /// The struct's generic parameters.
    pub(super) generics: syn::Generics,

    /// `schema = "..."`: the name of the schema to use.
    /// Mutually exclusive with `schema_path` (see `schema_input`).
    #[darling(default)]
    schema: Option<SpannedValue<String>>,
    /// `schema_path = "..."`: a path to the schema to use.
    /// Mutually exclusive with `schema` (see `schema_input`).
    #[darling(default)]
    schema_path: Option<SpannedValue<String>>,

    // Parsed from the `schema_module` attribute and read through the
    // `schema_module()` accessor, which falls back to `schema` when unset.
    #[darling(default, rename = "schema_module")]
    schema_module_: Option<syn::Path>,

    /// `graphql_type = "..."`: the GraphQL type this fragment selects from.
    /// When absent, the struct's own name is used (see `graphql_type_name`).
    #[darling(default)]
    pub graphql_type: Option<SpannedValue<String>>,

    /// `no_deserialize` flag.
    /// NOTE(review): presumably suppresses generated deserialization code;
    /// its consumer is not visible in this file — confirm.
    #[darling(default)]
    pub(super) no_deserialize: bool,

    /// `variables = ...`: a path to a type, exposed through `variables()`.
    /// NOTE(review): presumably the variables struct for the operation —
    /// confirm against callers.
    #[darling(default)]
    variables: Option<syn::Path>,
}
31
32impl FragmentDeriveInput {
33    pub fn schema_module(&self) -> syn::Path {
34        if let Some(schema_module) = &self.schema_module_ {
35            return schema_module.clone();
36        }
37        syn::parse2(quote::quote! { schema }).unwrap()
38    }
39
40    pub fn graphql_type_name(&self) -> String {
41        self.graphql_type
42            .as_ref()
43            .map(|sp| sp.to_string())
44            .unwrap_or_else(|| self.ident.to_string())
45    }
46
47    pub fn graphql_type_span(&self) -> Span {
48        self.graphql_type
49            .as_ref()
50            .map(|val| val.span())
51            .unwrap_or_else(|| self.ident.span())
52    }
53
54    pub fn validate(&self) -> Result<Vec<FragmentDeriveField>, Errors> {
55        let data_field_is_empty = matches!(self.data.clone(), darling::ast::Data::Struct(fields) if fields.fields.is_empty());
56        if data_field_is_empty {
57            return Err(syn::Error::new(
58                self.ident.span(),
59                format!(
60                    "At least one field should be selected for `{}`.",
61                    self.ident
62                ),
63            )
64            .into());
65        }
66
67        let mut fields = vec![];
68        let mut errors = Errors::default();
69
70        let results = self
71            .data
72            .clone()
73            .map_struct_fields(|field| field.validate())
74            .take_struct()
75            .unwrap()
76            .into_iter();
77
78        for result in results {
79            match result {
80                Ok(field) => fields.push(field),
81                Err(error) => errors.extend(error),
82            }
83        }
84
85        if !errors.is_empty() {
86            return Err(errors);
87        }
88
89        Ok(fields)
90    }
91
92    pub fn detect_aliases(&mut self) {
93        let mut names = HashSet::new();
94        if let darling::ast::Data::Struct(fields) = &mut self.data {
95            for field in &mut fields.fields {
96                if let Some(rename) = &mut field.rename {
97                    let name = rename.as_str();
98                    if names.contains(name) {
99                        field.alias = true.into();
100                        continue;
101                    }
102                    names.insert(name);
103                }
104            }
105        }
106    }
107
108    pub fn variables(&self) -> Option<syn::Path> {
109        self.variables.clone()
110    }
111
112    pub fn schema_input(&self) -> Result<SchemaInput, syn::Error> {
113        match (&self.schema, &self.schema_path) {
114            (None, None) => SchemaInput::default().map_err(|e| e.into_syn_error(Span::call_site())),
115            (None, Some(path)) => SchemaInput::from_schema_path(path.as_ref())
116                .map_err(|e| e.into_syn_error(path.span())),
117            (Some(name), None) => SchemaInput::from_schema_name(name.as_ref())
118                .map_err(|e| e.into_syn_error(name.span())),
119            (Some(_), Some(path)) => Err(syn::Error::new(
120                path.span(),
121                "Only one of schema_path & schema can be provided",
122            )),
123        }
124    }
125}
126
/// One field of a fragment struct as parsed by darling, before the
/// combination of its options has been validated.
///
/// `#[arguments(...)]` and `#[directives(...)]` attributes are forwarded
/// untouched into `attrs`; the remaining options come from `#[cynic(...)]`.
#[derive(darling::FromField, Clone)]
#[darling(attributes(cynic), forward_attrs(arguments, directives))]
pub struct RawFragmentDeriveField {
    /// The field's Rust name. Expected to always be `Some`, since the derive
    /// only accepts named structs.
    pub(super) ident: Option<proc_macro2::Ident>,
    /// The field's Rust type.
    pub(super) ty: syn::Type,

    /// The forwarded `arguments`/`directives` attributes; directives are parsed
    /// from these in `validate`.
    pub(super) attrs: Vec<syn::Attribute>,

    /// `flatten`: rejected together with `recurse`, `spread`, `default`, or a
    /// skip/include directive.
    #[darling(default)]
    pub(super) flatten: SpannedValue<bool>,

    /// `recurse = N`. NOTE(review): presumably a recursion depth limit —
    /// confirm.
    #[darling(default)]
    pub(super) recurse: Option<SpannedValue<u8>>,

    /// `spread`: rejected together with `flatten`, `recurse`, `default`, or a
    /// skip/include directive.
    #[darling(default)]
    pub(super) spread: SpannedValue<bool>,

    /// `rename = "..."`: overrides the GraphQL name derived from `ident`.
    #[darling(default)]
    rename: Option<SpannedValue<String>>,

    /// `alias`: query the field under an alias equal to its Rust name.
    /// Requires `rename`; also set by `FragmentDeriveInput::detect_aliases`.
    #[darling(default)]
    alias: SpannedValue<bool>,

    /// `feature = "..."`. Not read anywhere in this file.
    #[darling(default)]
    pub(super) feature: Option<SpannedValue<String>>,

    /// `default`: rejected together with `spread`, `recurse`, or `flatten`.
    #[darling(default)]
    pub(super) default: SpannedValue<bool>,
}
156
157pub struct FragmentDeriveField {
158    pub(super) raw_field: RawFragmentDeriveField,
159
160    pub(super) directives: Vec<super::directives::FieldDirective>,
161}
162
163impl RawFragmentDeriveField {
164    pub fn validate(self) -> Result<FragmentDeriveField, Errors> {
165        if *self.flatten && self.recurse.is_some() {
166            return Err(syn::Error::new(
167                self.recurse.as_ref().unwrap().span(),
168                "A field can't be recurse if it's being flattened",
169            )
170            .into());
171        }
172
173        if *self.flatten && *self.spread {
174            return Err(syn::Error::new(
175                self.flatten.span(),
176                "A field can't be flattened if it's also being spread",
177            )
178            .into());
179        }
180
181        if *self.spread && self.recurse.is_some() {
182            return Err(syn::Error::new(
183                self.recurse.as_ref().unwrap().span(),
184                "A field can't be recurse if it's being spread",
185            )
186            .into());
187        }
188
189        if *self.alias && self.rename.is_none() {
190            return Err(syn::Error::new(
191                self.alias.span(),
192                "You can only alias a renamed field.  Try removing `alias` or adding a rename",
193            )
194            .into());
195        }
196
197        if *self.default && *self.spread {
198            return Err(syn::Error::new(
199                self.default.span(),
200                "A field can't be defaulted if it's also being spread",
201            )
202            .into());
203        }
204
205        if *self.default && self.recurse.is_some() {
206            return Err(syn::Error::new(
207                self.recurse.unwrap().span(),
208                "A field can't be recurse if it's also being defaulted",
209            )
210            .into());
211        }
212
213        if *self.default && *self.flatten {
214            return Err(syn::Error::new(
215                self.default.span(),
216                "A field can't be defaulted if it's being flattened",
217            )
218            .into());
219        }
220
221        let directives = super::directives::directives_from_field_attrs(&self.attrs)?;
222        let skippable = directives.iter().any(|directive| {
223            matches!(
224                directive,
225                FieldDirective::Include(_) | FieldDirective::Skip(_)
226            )
227        });
228
229        if skippable {
230            if *self.spread {
231                return Err(syn::Error::new(
232                    self.spread.span(),
233                    "spread can't currently be used on fields with skip or include directives",
234                )
235                .into());
236            } else if *self.flatten {
237                return Err(syn::Error::new(
238                    self.flatten.span(),
239                    "flatten can't currently be used on fields with skip or include directives",
240                )
241                .into());
242            } else if let Some(recurse) = self.recurse {
243                return Err(syn::Error::new(
244                    recurse.span(),
245                    "recurse can't currently be used on fields with skip or include directives",
246                )
247                .into());
248            }
249        }
250
251        Ok(FragmentDeriveField {
252            directives,
253            raw_field: self,
254        })
255    }
256}
257
258impl FragmentDeriveField {
259    pub(super) fn type_check_mode(&self) -> CheckMode {
260        if *self.raw_field.flatten {
261            CheckMode::Flattening
262        } else if self.raw_field.recurse.is_some() {
263            CheckMode::Recursing
264        } else if *self.raw_field.spread {
265            CheckMode::Spreading
266        } else if self.has_default() {
267            CheckMode::Defaulted
268        } else if self.is_skippable() {
269            CheckMode::Skippable
270        } else {
271            CheckMode::OutputTypes
272        }
273    }
274
275    pub(super) fn is_skippable(&self) -> bool {
276        self.directives.iter().any(|directive| {
277            matches!(
278                directive,
279                FieldDirective::Include(_) | FieldDirective::Skip(_)
280            )
281        })
282    }
283
284    pub(super) fn spread(&self) -> bool {
285        *self.raw_field.spread
286    }
287
288    pub(super) fn ident(&self) -> Option<&proc_macro2::Ident> {
289        self.raw_field.ident.as_ref()
290    }
291
292    pub(super) fn graphql_ident(&self) -> RenamableFieldIdent {
293        let mut ident = RenamableFieldIdent::from(
294            self.raw_field
295                .ident
296                .clone()
297                .expect("FragmentDerive only supports named structs"),
298        );
299        if let Some(rename) = &self.raw_field.rename {
300            let span = rename.span();
301            let rename = (**rename).clone();
302            ident.set_rename(rename, span)
303        }
304        ident
305    }
306
307    pub(super) fn alias(&self) -> Option<String> {
308        self.raw_field.alias.then(|| {
309            self.raw_field
310                .ident
311                .as_ref()
312                .expect("ident is required")
313                .to_string()
314        })
315    }
316
317    pub(super) fn has_default(&self) -> bool {
318        *self.raw_field.default
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    use quote::format_ident;

    /// A `String` field called `name` with every cynic option switched off.
    ///
    /// Tests override just the options they care about via struct-update
    /// syntax, which keeps the interesting part of each fixture visible.
    fn plain_field(name: &str) -> RawFragmentDeriveField {
        RawFragmentDeriveField {
            ident: Some(format_ident!("{}", name)),
            ty: syn::parse_quote! { String },
            attrs: vec![],
            flatten: false.into(),
            recurse: None,
            spread: false.into(),
            rename: None,
            alias: false.into(),
            feature: None,
            default: false.into(),
        }
    }

    /// A derive input for a struct named `TestInput` over `fields`, with
    /// schema path, schema module and `graphql_type` all set to `abcd`.
    fn test_input(fields: Vec<RawFragmentDeriveField>) -> FragmentDeriveInput {
        FragmentDeriveInput {
            ident: format_ident!("TestInput"),
            data: darling::ast::Data::Struct(darling::ast::Fields::new(
                darling::ast::Style::Struct,
                fields,
            )),
            generics: Default::default(),
            schema: None,
            schema_path: Some("abcd".to_string().into()),
            schema_module_: Some(syn::parse_quote! { abcd }),
            graphql_type: Some("abcd".to_string().into()),
            variables: None,
            no_deserialize: false,
        }
    }

    #[test]
    fn test_fragment_derive_validate_pass() {
        let input = FragmentDeriveInput {
            schema_module_: None,
            ..test_input(vec![
                plain_field("field_one"),
                RawFragmentDeriveField {
                    flatten: true.into(),
                    ..plain_field("field_two")
                },
                RawFragmentDeriveField {
                    recurse: Some(8.into()),
                    rename: Some("fieldThree".to_string().into()),
                    ..plain_field("field_three")
                },
                RawFragmentDeriveField {
                    spread: true.into(),
                    rename: Some("fieldThree".to_string().into()),
                    alias: true.into(),
                    ..plain_field("some_spread")
                },
            ])
        };

        assert!(input.validate().is_ok());
    }

    #[test]
    fn test_fragment_derive_validate_fails() {
        let input = test_input(vec![
            plain_field("field_one"),
            RawFragmentDeriveField {
                flatten: true.into(),
                recurse: Some(8.into()),
                ..plain_field("field_two")
            },
            RawFragmentDeriveField {
                flatten: true.into(),
                recurse: Some(8.into()),
                ..plain_field("field_three")
            },
            RawFragmentDeriveField {
                flatten: true.into(),
                spread: true.into(),
                ..plain_field("some_spread")
            },
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                spread: true.into(),
                ..plain_field("some_other_spread")
            },
            RawFragmentDeriveField {
                recurse: Some(8.into()),
                spread: true.into(),
                alias: true.into(),
                ..plain_field("some_other_spread")
            },
        ]);

        // Every field except `field_one` is invalid; each contributes one error.
        let errors = input.validate().map(|_| ()).unwrap_err();
        assert_eq!(errors.len(), 5);
    }

    #[test]
    fn test_fragment_derive_validate_failed() {
        let input = test_input(vec![]);
        let errors = input.validate().map(|_| ()).unwrap_err();
        insta::assert_snapshot!(errors.to_compile_errors().to_string(), @r###":: core :: compile_error ! { "At least one field should be selected for `TestInput`." }"###);
    }

    #[test]
    fn test_fragment_derive_validate_pass_no_graphql_type() {
        let input = FragmentDeriveInput {
            graphql_type: None,
            ..test_input(vec![
                plain_field("field_one"),
                RawFragmentDeriveField {
                    flatten: true.into(),
                    ..plain_field("field_two")
                },
                RawFragmentDeriveField {
                    recurse: Some(8.into()),
                    ..plain_field("field_three")
                },
            ])
        };

        assert!(input.validate().is_ok())
    }

    #[test]
    fn test_detect_aliases_marks_repeated_renames() {
        let mut input = test_input(vec![
            RawFragmentDeriveField {
                rename: Some("film".to_string().into()),
                ..plain_field("first_film")
            },
            RawFragmentDeriveField {
                rename: Some("film".to_string().into()),
                ..plain_field("second_film")
            },
            plain_field("title"),
        ]);

        input.detect_aliases();

        let fields = match input.validate() {
            Ok(fields) => fields,
            Err(errors) => panic!("{}", errors.to_compile_errors()),
        };
        let aliases: Vec<Option<String>> = fields.iter().map(|field| field.alias()).collect();

        // Only the second use of the `film` rename needs an alias.
        assert_eq!(aliases, vec![None, Some("second_film".to_string()), None]);
    }
}