cynic_codegen/inline_fragments_derive/mod.rs

1use {
2    darling::util::SpannedValue,
3    proc_macro2::{Span, TokenStream},
4};
5
6use crate::{
7    inline_fragments_derive::input::ValidationMode,
8    schema::{
9        markers::TypeMarkerIdent,
10        types::{InterfaceType, Kind, Type, UnionType},
11        Schema, SchemaError,
12    },
13    variables_fields_path, Errors,
14};
15
16pub mod input;
17
18mod exhaustiveness;
19mod inline_fragments_impl;
20
21#[cfg(test)]
22mod tests;
23
24pub use input::InlineFragmentsDeriveInput;
25
26use input::InlineFragmentsDeriveVariant;
27
28use self::inline_fragments_impl::Fallback;
29
30pub fn inline_fragments_derive(ast: &syn::DeriveInput) -> Result<TokenStream, Errors> {
31    use darling::FromDeriveInput;
32
33    match InlineFragmentsDeriveInput::from_derive_input(ast) {
34        Ok(input) => inline_fragments_derive_impl(input),
35        Err(e) => Ok(e.write_errors()),
36    }
37}
38
39pub(crate) fn inline_fragments_derive_impl(
40    input: InlineFragmentsDeriveInput,
41) -> Result<TokenStream, Errors> {
42    use quote::quote;
43
44    let schema = Schema::new(input.schema_input()?);
45
46    let target_type = schema.lookup::<InlineFragmentType<'_>>(&input.graphql_type_name())?;
47
48    input.validate(match target_type {
49        InlineFragmentType::Union(_) => ValidationMode::Union,
50        InlineFragmentType::Interface(_) => ValidationMode::Interface,
51    })?;
52
53    let variables = input.variables();
54
55    if let darling::ast::Data::Enum(variants) = &input.data {
56        if input.exhaustive.map(|e| *e).unwrap_or_default() {
57            exhaustiveness::exhaustiveness_check(variants, &target_type)?;
58        }
59
60        let fallback = check_fallback(variants, &target_type)?;
61
62        let type_lock = target_type.marker_ident().to_path(&input.schema_module());
63
64        let fragments = fragments_from_variants(variants)?;
65
66        let query_fragment_impl = QueryFragmentImpl {
67            target_enum: input.ident.clone(),
68            generics: &input.generics,
69            type_lock,
70            variables,
71            fragments: &fragments,
72            graphql_type_name: input.graphql_type_name(),
73            fallback: fallback.clone(),
74        };
75
76        let inline_fragments_impl = inline_fragments_impl::InlineFragmentsImpl {
77            target_enum: input.ident.clone(),
78            generics: &input.generics,
79            fragments: &fragments,
80            fallback,
81        };
82
83        Ok(quote! {
84            #inline_fragments_impl
85            #query_fragment_impl
86        })
87    } else {
88        Err(syn::Error::new(
89            Span::call_site(),
90            "InlineFragments can only be derived from an enum".to_string(),
91        )
92        .into())
93    }
94}
95
96struct Fragment {
97    rust_variant_name: syn::Ident,
98    inner_type: syn::Type,
99}
100
101fn fragments_from_variants(
102    variants: &[SpannedValue<InlineFragmentsDeriveVariant>],
103) -> Result<Vec<Fragment>, syn::Error> {
104    let mut result = vec![];
105    for variant in variants {
106        if *variant.fallback {
107            continue;
108        }
109
110        if variant.fields.style != darling::ast::Style::Tuple || variant.fields.fields.len() != 1 {
111            return Err(syn::Error::new(
112                variant.span(),
113                "InlineFragments derive requires enum variants to have one unnamed field",
114            ));
115        }
116        let field = variant.fields.fields.first().unwrap();
117        result.push(Fragment {
118            rust_variant_name: variant.ident.clone(),
119            inner_type: field.ty.clone(),
120        });
121    }
122    Ok(result)
123}
124
125fn check_fallback(
126    variants: &[SpannedValue<InlineFragmentsDeriveVariant>],
127    target_type: &InlineFragmentType<'_>,
128) -> Result<Option<Fallback>, Errors> {
129    let fallbacks = variants.iter().filter(|v| *v.fallback).collect::<Vec<_>>();
130
131    if fallbacks.is_empty() {
132        return Ok(None);
133    }
134
135    if fallbacks.len() > 1 {
136        let mut errors = Errors::default();
137        for fallback in &fallbacks[1..] {
138            errors.push(syn::Error::new(
139                fallback.span(),
140                "InlineFragments can't have more than one fallback",
141            ))
142        }
143
144        return Err(errors);
145    }
146
147    let fallback = fallbacks[0];
148
149    match fallback.fields.style {
150        darling::ast::Style::Struct => Err(syn::Error::new(
151            fallback.span(),
152            "InlineFragment fallbacks don't currently support struct variants",
153        )
154        .into()),
155        darling::ast::Style::Tuple => {
156            if fallback.fields.len() != 1 {
157                return Err(syn::Error::new(
158                    fallback.span(),
159                    "InlineFragments require variants to have one unnamed field",
160                )
161                .into());
162            }
163            Ok(Some(match target_type {
164                InlineFragmentType::Interface(_) => Fallback::InterfaceVariant(
165                    fallback.ident.clone(),
166                    fallback.fields.fields[0].ty.clone(),
167                ),
168                InlineFragmentType::Union(_) => Fallback::UnionVariantWithTypename(
169                    fallback.ident.clone(),
170                    fallback.fields.fields[0].ty.clone(),
171                ),
172            }))
173        }
174        darling::ast::Style::Unit => Ok(Some(Fallback::UnionUnitVariant(fallback.ident.clone()))),
175    }
176}
177
178struct QueryFragmentImpl<'a> {
179    target_enum: syn::Ident,
180    generics: &'a syn::Generics,
181    type_lock: syn::Path,
182    variables: Option<syn::Path>,
183    fragments: &'a [Fragment],
184    graphql_type_name: String,
185    fallback: Option<Fallback>,
186}
187
188impl quote::ToTokens for QueryFragmentImpl<'_> {
189    fn to_tokens(&self, tokens: &mut TokenStream) {
190        use quote::{quote, TokenStreamExt};
191
192        let target_struct = &self.target_enum;
193        let type_lock = &self.type_lock;
194        let variables_fields = variables_fields_path(self.variables.as_ref());
195        let variables_fields = match &variables_fields {
196            Some(path) => quote! { #path },
197            None => quote! { () },
198        };
199        let inner_types: Vec<_> = self
200            .fragments
201            .iter()
202            .map(|fragment| &fragment.inner_type)
203            .collect();
204        let graphql_type = proc_macro2::Literal::string(&self.graphql_type_name);
205        let fallback_selection = match &self.fallback {
206            Some(Fallback::InterfaceVariant(_, fallback_fragment)) => quote! {
207                <#fallback_fragment>::query(builder);
208            },
209            _ => quote! {},
210        };
211
212        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
213
214        tokens.append_all(quote! {
215            #[automatically_derived]
216            impl #impl_generics cynic::QueryFragment for #target_struct #ty_generics #where_clause {
217                type SchemaType = #type_lock;
218                type VariablesFields = #variables_fields;
219
220                const TYPE: Option<&'static str> = Some(#graphql_type);
221
222                fn query(mut builder: cynic::queries::SelectionBuilder<'_, Self::SchemaType, Self::VariablesFields>) {
223                    #(
224                        let fragment_builder = builder.inline_fragment();
225                        let mut fragment_builder = fragment_builder.on::<<#inner_types as cynic::QueryFragment>::SchemaType>();
226                        <#inner_types as cynic::QueryFragment>::query(
227                            fragment_builder.select_children()
228                        );
229                    )*
230
231                    #fallback_selection
232                }
233            }
234        })
235    }
236}
237
238enum InlineFragmentType<'a> {
239    Union(UnionType<'a>),
240    Interface(InterfaceType<'a>),
241}
242
243impl InlineFragmentType<'_> {
244    pub fn name(&self) -> &str {
245        match self {
246            InlineFragmentType::Union(union_type) => &union_type.name,
247            InlineFragmentType::Interface(interface) => &interface.name,
248        }
249    }
250}
251
252impl<'a> InlineFragmentType<'a> {
253    pub fn marker_ident(&self) -> TypeMarkerIdent<'a> {
254        match self {
255            InlineFragmentType::Union(inner) => inner.marker_ident(),
256            InlineFragmentType::Interface(inner) => inner.marker_ident(),
257        }
258    }
259}
260
261impl<'a> TryFrom<Type<'a>> for InlineFragmentType<'a> {
262    type Error = SchemaError;
263
264    fn try_from(value: Type<'a>) -> Result<Self, Self::Error> {
265        match value {
266            Type::Interface(inner) => Ok(InlineFragmentType::Interface(inner)),
267            Type::Union(inner) => Ok(InlineFragmentType::Union(inner)),
268            _ => Err(SchemaError::unexpected_kind(value, Kind::UnionOrInterface)),
269        }
270    }
271}