cynic_codegen/inline_fragments_derive/
mod.rs1use {
2 darling::util::SpannedValue,
3 proc_macro2::{Span, TokenStream},
4};
5
6use crate::{
7 inline_fragments_derive::input::ValidationMode,
8 schema::{
9 markers::TypeMarkerIdent,
10 types::{InterfaceType, Kind, Type, UnionType},
11 Schema, SchemaError,
12 },
13 variables_fields_path, Errors,
14};
15
16pub mod input;
17
18mod exhaustiveness;
19mod inline_fragments_impl;
20
21#[cfg(test)]
22mod tests;
23
24pub use input::InlineFragmentsDeriveInput;
25
26use input::InlineFragmentsDeriveVariant;
27
28use self::inline_fragments_impl::Fallback;
29
30pub fn inline_fragments_derive(ast: &syn::DeriveInput) -> Result<TokenStream, Errors> {
31 use darling::FromDeriveInput;
32
33 match InlineFragmentsDeriveInput::from_derive_input(ast) {
34 Ok(input) => inline_fragments_derive_impl(input),
35 Err(e) => Ok(e.write_errors()),
36 }
37}
38
39pub(crate) fn inline_fragments_derive_impl(
40 input: InlineFragmentsDeriveInput,
41) -> Result<TokenStream, Errors> {
42 use quote::quote;
43
44 let schema = Schema::new(input.schema_input()?);
45
46 let target_type = schema.lookup::<InlineFragmentType<'_>>(&input.graphql_type_name())?;
47
48 input.validate(match target_type {
49 InlineFragmentType::Union(_) => ValidationMode::Union,
50 InlineFragmentType::Interface(_) => ValidationMode::Interface,
51 })?;
52
53 let variables = input.variables();
54
55 if let darling::ast::Data::Enum(variants) = &input.data {
56 if input.exhaustive.map(|e| *e).unwrap_or_default() {
57 exhaustiveness::exhaustiveness_check(variants, &target_type)?;
58 }
59
60 let fallback = check_fallback(variants, &target_type)?;
61
62 let type_lock = target_type.marker_ident().to_path(&input.schema_module());
63
64 let fragments = fragments_from_variants(variants)?;
65
66 let query_fragment_impl = QueryFragmentImpl {
67 target_enum: input.ident.clone(),
68 generics: &input.generics,
69 type_lock,
70 variables,
71 fragments: &fragments,
72 graphql_type_name: input.graphql_type_name(),
73 fallback: fallback.clone(),
74 };
75
76 let inline_fragments_impl = inline_fragments_impl::InlineFragmentsImpl {
77 target_enum: input.ident.clone(),
78 generics: &input.generics,
79 fragments: &fragments,
80 fallback,
81 };
82
83 Ok(quote! {
84 #inline_fragments_impl
85 #query_fragment_impl
86 })
87 } else {
88 Err(syn::Error::new(
89 Span::call_site(),
90 "InlineFragments can only be derived from an enum".to_string(),
91 )
92 .into())
93 }
94}
95
96struct Fragment {
97 rust_variant_name: syn::Ident,
98 inner_type: syn::Type,
99}
100
101fn fragments_from_variants(
102 variants: &[SpannedValue<InlineFragmentsDeriveVariant>],
103) -> Result<Vec<Fragment>, syn::Error> {
104 let mut result = vec![];
105 for variant in variants {
106 if *variant.fallback {
107 continue;
108 }
109
110 if variant.fields.style != darling::ast::Style::Tuple || variant.fields.fields.len() != 1 {
111 return Err(syn::Error::new(
112 variant.span(),
113 "InlineFragments derive requires enum variants to have one unnamed field",
114 ));
115 }
116 let field = variant.fields.fields.first().unwrap();
117 result.push(Fragment {
118 rust_variant_name: variant.ident.clone(),
119 inner_type: field.ty.clone(),
120 });
121 }
122 Ok(result)
123}
124
125fn check_fallback(
126 variants: &[SpannedValue<InlineFragmentsDeriveVariant>],
127 target_type: &InlineFragmentType<'_>,
128) -> Result<Option<Fallback>, Errors> {
129 let fallbacks = variants.iter().filter(|v| *v.fallback).collect::<Vec<_>>();
130
131 if fallbacks.is_empty() {
132 return Ok(None);
133 }
134
135 if fallbacks.len() > 1 {
136 let mut errors = Errors::default();
137 for fallback in &fallbacks[1..] {
138 errors.push(syn::Error::new(
139 fallback.span(),
140 "InlineFragments can't have more than one fallback",
141 ))
142 }
143
144 return Err(errors);
145 }
146
147 let fallback = fallbacks[0];
148
149 match fallback.fields.style {
150 darling::ast::Style::Struct => Err(syn::Error::new(
151 fallback.span(),
152 "InlineFragment fallbacks don't currently support struct variants",
153 )
154 .into()),
155 darling::ast::Style::Tuple => {
156 if fallback.fields.len() != 1 {
157 return Err(syn::Error::new(
158 fallback.span(),
159 "InlineFragments require variants to have one unnamed field",
160 )
161 .into());
162 }
163 Ok(Some(match target_type {
164 InlineFragmentType::Interface(_) => Fallback::InterfaceVariant(
165 fallback.ident.clone(),
166 fallback.fields.fields[0].ty.clone(),
167 ),
168 InlineFragmentType::Union(_) => Fallback::UnionVariantWithTypename(
169 fallback.ident.clone(),
170 fallback.fields.fields[0].ty.clone(),
171 ),
172 }))
173 }
174 darling::ast::Style::Unit => Ok(Some(Fallback::UnionUnitVariant(fallback.ident.clone()))),
175 }
176}
177
178struct QueryFragmentImpl<'a> {
179 target_enum: syn::Ident,
180 generics: &'a syn::Generics,
181 type_lock: syn::Path,
182 variables: Option<syn::Path>,
183 fragments: &'a [Fragment],
184 graphql_type_name: String,
185 fallback: Option<Fallback>,
186}
187
188impl quote::ToTokens for QueryFragmentImpl<'_> {
189 fn to_tokens(&self, tokens: &mut TokenStream) {
190 use quote::{quote, TokenStreamExt};
191
192 let target_struct = &self.target_enum;
193 let type_lock = &self.type_lock;
194 let variables_fields = variables_fields_path(self.variables.as_ref());
195 let variables_fields = match &variables_fields {
196 Some(path) => quote! { #path },
197 None => quote! { () },
198 };
199 let inner_types: Vec<_> = self
200 .fragments
201 .iter()
202 .map(|fragment| &fragment.inner_type)
203 .collect();
204 let graphql_type = proc_macro2::Literal::string(&self.graphql_type_name);
205 let fallback_selection = match &self.fallback {
206 Some(Fallback::InterfaceVariant(_, fallback_fragment)) => quote! {
207 <#fallback_fragment>::query(builder);
208 },
209 _ => quote! {},
210 };
211
212 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
213
214 tokens.append_all(quote! {
215 #[automatically_derived]
216 impl #impl_generics cynic::QueryFragment for #target_struct #ty_generics #where_clause {
217 type SchemaType = #type_lock;
218 type VariablesFields = #variables_fields;
219
220 const TYPE: Option<&'static str> = Some(#graphql_type);
221
222 fn query(mut builder: cynic::queries::SelectionBuilder<'_, Self::SchemaType, Self::VariablesFields>) {
223 #(
224 let fragment_builder = builder.inline_fragment();
225 let mut fragment_builder = fragment_builder.on::<<#inner_types as cynic::QueryFragment>::SchemaType>();
226 <#inner_types as cynic::QueryFragment>::query(
227 fragment_builder.select_children()
228 );
229 )*
230
231 #fallback_selection
232 }
233 }
234 })
235 }
236}
237
238enum InlineFragmentType<'a> {
239 Union(UnionType<'a>),
240 Interface(InterfaceType<'a>),
241}
242
243impl InlineFragmentType<'_> {
244 pub fn name(&self) -> &str {
245 match self {
246 InlineFragmentType::Union(union_type) => &union_type.name,
247 InlineFragmentType::Interface(interface) => &interface.name,
248 }
249 }
250}
251
252impl<'a> InlineFragmentType<'a> {
253 pub fn marker_ident(&self) -> TypeMarkerIdent<'a> {
254 match self {
255 InlineFragmentType::Union(inner) => inner.marker_ident(),
256 InlineFragmentType::Interface(inner) => inner.marker_ident(),
257 }
258 }
259}
260
261impl<'a> TryFrom<Type<'a>> for InlineFragmentType<'a> {
262 type Error = SchemaError;
263
264 fn try_from(value: Type<'a>) -> Result<Self, Self::Error> {
265 match value {
266 Type::Interface(inner) => Ok(InlineFragmentType::Interface(inner)),
267 Type::Union(inner) => Ok(InlineFragmentType::Union(inner)),
268 _ => Err(SchemaError::unexpected_kind(value, Kind::UnionOrInterface)),
269 }
270 }
271}