async_graphql_derive/
interface.rs

1use std::collections::HashSet;
2
3use darling::ast::{Data, Style};
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, Span};
6use quote::quote;
7use syn::{Error, Type, visit_mut::VisitMut};
8
9use crate::{
10    args::{
11        self, InterfaceField, InterfaceFieldArgument, RenameRuleExt, RenameTarget,
12        TypeDirectiveLocation,
13    },
14    output_type::OutputType,
15    utils::{
16        GeneratorResult, RemoveLifetime, gen_boxed_trait, gen_deprecation, gen_directive_calls,
17        generate_default, get_crate_path, get_rustdoc, visible_fn,
18    },
19};
20
21pub fn generate(interface_args: &args::Interface) -> GeneratorResult<TokenStream> {
22    let crate_name = get_crate_path(&interface_args.crate_path, interface_args.internal);
23    let boxed_trait = gen_boxed_trait(&crate_name);
24    let ident = &interface_args.ident;
25    let type_params = interface_args.generics.type_params().collect::<Vec<_>>();
26    let (impl_generics, ty_generics, where_clause) = interface_args.generics.split_for_impl();
27    let s = match &interface_args.data {
28        Data::Enum(s) => s,
29        _ => {
30            return Err(
31                Error::new_spanned(ident, "Interface can only be applied to an enum.").into(),
32            );
33        }
34    };
35    let extends = interface_args.extends;
36    let mut enum_names = Vec::new();
37    let mut enum_items = HashSet::new();
38    let mut type_into_impls = Vec::new();
39    let inaccessible = interface_args.inaccessible;
40    let tags = interface_args
41        .tags
42        .iter()
43        .map(|tag| quote!(::std::string::ToString::to_string(#tag)))
44        .collect::<Vec<_>>();
45    let requires_scopes = interface_args
46        .requires_scopes
47        .iter()
48        .map(|scopes| quote!(::std::string::ToString::to_string(#scopes)))
49        .collect::<Vec<_>>();
50
51    let directives = gen_directive_calls(
52        &crate_name,
53        &interface_args.directives,
54        TypeDirectiveLocation::Interface,
55    );
56    let gql_typename = if !interface_args.name_type {
57        let name = interface_args
58            .name
59            .clone()
60            .unwrap_or_else(|| RenameTarget::Type.rename(ident.to_string()));
61        quote!(::std::borrow::Cow::Borrowed(#name))
62    } else {
63        quote!(<Self as #crate_name::TypeName>::type_name())
64    };
65
66    let desc = get_rustdoc(&interface_args.attrs)?
67        .map(|s| quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#s)) })
68        .unwrap_or_else(|| quote! {::std::option::Option::None});
69
70    let mut registry_types = Vec::new();
71    let mut possible_types = Vec::new();
72    let mut get_introspection_typename = Vec::new();
73    let mut collect_all_fields = Vec::new();
74
75    for variant in s {
76        let enum_name = &variant.ident;
77        let ty = match variant.fields.style {
78            Style::Tuple if variant.fields.fields.len() == 1 => &variant.fields.fields[0],
79            Style::Tuple => {
80                return Err(Error::new_spanned(
81                    enum_name,
82                    "Only single value variants are supported",
83                )
84                .into());
85            }
86            Style::Unit => {
87                return Err(
88                    Error::new_spanned(enum_name, "Empty variants are not supported").into(),
89                );
90            }
91            Style::Struct => {
92                return Err(Error::new_spanned(
93                    enum_name,
94                    "Variants with named fields are not supported",
95                )
96                .into());
97            }
98        };
99
100        if let Type::Path(p) = ty {
101            // This validates that the field type wasn't already used
102            if !enum_items.insert(p) {
103                return Err(
104                    Error::new_spanned(ty, "This type already used in another variant").into(),
105                );
106            }
107
108            let mut assert_ty = p.clone();
109            RemoveLifetime.visit_type_path_mut(&mut assert_ty);
110
111            type_into_impls.push(quote! {
112                #crate_name::static_assertions_next::assert_impl!(for(#(#type_params),*) #assert_ty: (#crate_name::ObjectType) | (#crate_name::InterfaceType));
113
114                #[allow(clippy::all, clippy::pedantic)]
115                impl #impl_generics ::std::convert::From<#p> for #ident #ty_generics #where_clause {
116                    fn from(obj: #p) -> Self {
117                        #ident::#enum_name(obj)
118                    }
119                }
120            });
121            enum_names.push(enum_name);
122
123            registry_types.push(quote! {
124                <#p as #crate_name::OutputType>::create_type_info(registry);
125                registry.add_implements(&<#p as #crate_name::OutputType>::type_name(), ::std::convert::AsRef::as_ref(&#gql_typename));
126            });
127
128            possible_types.push(quote! {
129                possible_types.insert(<#p as #crate_name::OutputType>::type_name().into_owned());
130            });
131
132            get_introspection_typename.push(quote! {
133                #ident::#enum_name(obj) => <#p as #crate_name::OutputType>::type_name()
134            });
135
136            collect_all_fields.push(quote! {
137                #ident::#enum_name(obj) => obj.collect_all_fields(ctx, fields)
138            });
139        } else {
140            return Err(Error::new_spanned(ty, "Invalid type").into());
141        }
142    }
143
144    let mut methods = Vec::new();
145    let mut schema_fields = Vec::new();
146    let mut resolvers = Vec::new();
147
148    if interface_args.fields.is_empty() {
149        return Err(Error::new_spanned(
150            ident,
151            "A GraphQL Interface type must define one or more fields.",
152        )
153        .into());
154    }
155
156    for InterfaceField {
157        name,
158        method,
159        desc,
160        ty,
161        args,
162        deprecation,
163        external,
164        provides,
165        requires,
166        visible,
167        shareable,
168        inaccessible,
169        tags,
170        override_from,
171        directives,
172        requires_scopes,
173    } in &interface_args.fields
174    {
175        let (name, method_name) = if let Some(method) = method {
176            (name.to_string(), Ident::new_raw(method, Span::call_site()))
177        } else {
178            let method_name = Ident::new_raw(name, Span::call_site());
179            (
180                interface_args
181                    .rename_fields
182                    .rename(name.as_ref(), RenameTarget::Field),
183                method_name,
184            )
185        };
186        let mut calls = Vec::new();
187        let mut use_params = Vec::new();
188        let mut decl_params = Vec::new();
189        let mut get_params = Vec::new();
190        let mut schema_args = Vec::new();
191        let requires = match &requires {
192            Some(requires) => {
193                quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#requires)) }
194            }
195            None => quote! { ::std::option::Option::None },
196        };
197        let provides = match &provides {
198            Some(provides) => {
199                quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#provides)) }
200            }
201            None => quote! { ::std::option::Option::None },
202        };
203        let override_from = match &override_from {
204            Some(from) => {
205                quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#from)) }
206            }
207            None => quote! { ::std::option::Option::None },
208        };
209
210        decl_params.push(quote! { ctx: &'ctx #crate_name::Context<'ctx> });
211        use_params.push(quote! { ctx });
212
213        for (
214            i,
215            InterfaceFieldArgument {
216                name,
217                desc,
218                ty,
219                default,
220                default_with,
221                visible,
222                inaccessible,
223                tags,
224                secret,
225                directives,
226                deprecation,
227            },
228        ) in args.iter().enumerate()
229        {
230            let ident = Ident::new(&format!("arg{}", i), Span::call_site());
231            let name = interface_args
232                .rename_args
233                .rename(name, RenameTarget::Argument);
234            decl_params.push(quote! { #ident: #ty });
235            use_params.push(quote! { #ident });
236
237            let default = generate_default(default, default_with)?;
238            let get_default = match &default {
239                Some(default) => quote! { ::std::option::Option::Some(|| -> #ty { #default }) },
240                None => quote! { ::std::option::Option::None },
241            };
242            get_params.push(quote! {
243                let (_, #ident) = ctx.param_value::<#ty>(#name, #get_default)?;
244            });
245
246            let desc = desc
247                .as_ref()
248                .map(|s| quote! {::std::option::Option::Some(::std::string::ToString::to_string(#s))})
249                .unwrap_or_else(|| quote! {::std::option::Option::None});
250            let schema_default = default
251                .as_ref()
252                .map(|value| {
253                    quote! {
254                        ::std::option::Option::Some(::std::string::ToString::to_string(
255                            &<#ty as #crate_name::InputType>::to_value(&#value)
256                        ))
257                    }
258                })
259                .unwrap_or_else(|| quote! {::std::option::Option::None});
260            let visible = visible_fn(visible);
261            let tags = tags
262                .iter()
263                .map(|tag| quote!(::std::string::ToString::to_string(#tag)))
264                .collect::<Vec<_>>();
265            let directives = gen_directive_calls(
266                &crate_name,
267                directives,
268                TypeDirectiveLocation::ArgumentDefinition,
269            );
270            let deprecation = gen_deprecation(deprecation, &crate_name);
271
272            schema_args.push(quote! {
273                    args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue {
274                        name: ::std::string::ToString::to_string(#name),
275                        description: #desc,
276                        ty: <#ty as #crate_name::InputType>::create_type_info(registry),
277                        deprecation: #deprecation,
278                        default_value: #schema_default,
279                        visible: #visible,
280                        inaccessible: #inaccessible,
281                        tags: ::std::vec![ #(#tags),* ],
282                        is_secret: #secret,
283                        directive_invocations: ::std::vec![ #(#directives),* ],
284                    });
285                });
286        }
287
288        for enum_name in &enum_names {
289            calls.push(quote! {
290                #ident::#enum_name(obj) => obj.#method_name(#(#use_params),*)
291                    .await.map_err(|err| ::std::convert::Into::<#crate_name::Error>::into(err))
292                    .map(::std::convert::Into::into)
293            });
294        }
295
296        let desc = desc
297            .as_ref()
298            .map(|s| quote! {::std::option::Option::Some(::std::string::ToString::to_string(#s))})
299            .unwrap_or_else(|| quote! {::std::option::Option::None});
300        let deprecation = gen_deprecation(deprecation, &crate_name);
301
302        let oty = OutputType::parse(ty)?;
303        let ty = match oty {
304            OutputType::Value(ty) => ty,
305            OutputType::Result(ty) => ty,
306        };
307        let schema_ty = oty.value_type();
308
309        methods.push(quote! {
310            #[allow(missing_docs)]
311            #[inline]
312            pub async fn #method_name<'ctx>(&self, #(#decl_params),*) -> #crate_name::Result<#ty> {
313                match self {
314                    #(#calls,)*
315                }
316            }
317        });
318
319        let visible = visible_fn(visible);
320        let tags = tags
321            .iter()
322            .map(|tag| quote!(::std::string::ToString::to_string(#tag)))
323            .collect::<Vec<_>>();
324        let requires_scopes = requires_scopes
325            .iter()
326            .map(|scopes| quote!(::std::string::ToString::to_string(#scopes)))
327            .collect::<Vec<_>>();
328
329        let directives = gen_directive_calls(
330            &crate_name,
331            directives,
332            TypeDirectiveLocation::FieldDefinition,
333        );
334
335        schema_fields.push(quote! {
336            fields.insert(::std::string::ToString::to_string(#name), #crate_name::registry::MetaField {
337                name: ::std::string::ToString::to_string(#name),
338                description: #desc,
339                args: {
340                    let mut args = #crate_name::indexmap::IndexMap::new();
341                    #(#schema_args)*
342                    args
343                },
344                ty: <#schema_ty as #crate_name::OutputType>::create_type_info(registry),
345                deprecation: #deprecation,
346                cache_control: ::std::default::Default::default(),
347                external: #external,
348                provides: #provides,
349                requires: #requires,
350                shareable: #shareable,
351                inaccessible: #inaccessible,
352                tags: ::std::vec![ #(#tags),* ],
353                override_from: #override_from,
354                visible: #visible,
355                compute_complexity: ::std::option::Option::None,
356                directive_invocations: ::std::vec![ #(#directives),* ],
357                requires_scopes: ::std::vec![ #(#requires_scopes),* ],
358            });
359        });
360
361        let resolve_obj = quote! {
362            self.#method_name(#(#use_params),*)
363                .await
364                .map_err(|err| ::std::convert::Into::<#crate_name::Error>::into(err).into_server_error(ctx.item.pos))?
365        };
366
367        resolvers.push(quote! {
368            if ctx.item.node.name.node == #name {
369                #(#get_params)*
370                let ctx_obj = ctx.with_selection_set(&ctx.item.node.selection_set);
371                return #crate_name::OutputType::resolve(&#resolve_obj, &ctx_obj, ctx.item).await.map(::std::option::Option::Some);
372            }
373        });
374    }
375
376    let introspection_type_name = if get_introspection_typename.is_empty() {
377        quote! { ::std::unreachable!() }
378    } else {
379        quote! {
380            match self {
381            #(#get_introspection_typename),*
382            }
383        }
384    };
385
386    let visible = visible_fn(&interface_args.visible);
387    let expanded = quote! {
388        #(#type_into_impls)*
389
390        #[allow(clippy::all, clippy::pedantic)]
391        impl #impl_generics #ident #ty_generics #where_clause {
392            #(#methods)*
393        }
394
395        #[allow(clippy::all, clippy::pedantic)]
396        #boxed_trait
397        impl #impl_generics #crate_name::resolver_utils::ContainerType for #ident #ty_generics #where_clause {
398            async fn resolve_field(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::ServerResult<::std::option::Option<#crate_name::Value>> {
399                #(#resolvers)*
400                ::std::result::Result::Ok(::std::option::Option::None)
401            }
402
403            fn collect_all_fields<'__life>(&'__life self, ctx: &#crate_name::ContextSelectionSet<'__life>, fields: &mut #crate_name::resolver_utils::Fields<'__life>) -> #crate_name::ServerResult<()> {
404                match self {
405                    #(#collect_all_fields),*
406                }
407            }
408        }
409
410        #[allow(clippy::all, clippy::pedantic)]
411        #boxed_trait
412        impl #impl_generics #crate_name::OutputType for #ident #ty_generics #where_clause {
413            fn type_name() -> ::std::borrow::Cow<'static, ::std::primitive::str> {
414                #gql_typename
415            }
416
417            fn introspection_type_name(&self) -> ::std::borrow::Cow<'static, ::std::primitive::str> {
418                #introspection_type_name
419            }
420
421            fn create_type_info(registry: &mut #crate_name::registry::Registry) -> ::std::string::String {
422                registry.create_output_type::<Self, _>(#crate_name::registry::MetaTypeId::Interface, |registry| {
423                    #(#registry_types)*
424
425                    #crate_name::registry::MetaType::Interface {
426                        name: ::std::borrow::Cow::into_owned(#gql_typename),
427                        description: #desc,
428                        fields: {
429                            let mut fields = #crate_name::indexmap::IndexMap::new();
430                            #(#schema_fields)*
431                            fields
432                        },
433                        possible_types: {
434                            let mut possible_types = #crate_name::indexmap::IndexSet::new();
435                            #(#possible_types)*
436                            possible_types
437                        },
438                        extends: #extends,
439                        keys: ::std::option::Option::None,
440                        visible: #visible,
441                        inaccessible: #inaccessible,
442                        tags: ::std::vec![ #(#tags),* ],
443                        rust_typename: ::std::option::Option::Some(::std::any::type_name::<Self>()),
444                        directive_invocations: ::std::vec![ #(#directives),* ],
445                        requires_scopes: ::std::vec![ #(#requires_scopes),* ],
446                    }
447                })
448            }
449
450            async fn resolve(
451                &self,
452                ctx: &#crate_name::ContextSelectionSet<'_>,
453                _field: &#crate_name::Positioned<#crate_name::parser::types::Field>,
454            ) -> #crate_name::ServerResult<#crate_name::Value> {
455                #crate_name::resolver_utils::resolve_container(ctx, self).await
456            }
457        }
458
459        impl #impl_generics #crate_name::InterfaceType for #ident #ty_generics #where_clause {}
460    };
461    Ok(expanded.into())
462}