async_graphql_derive/
interface.rs

1use std::collections::HashSet;
2
3use darling::ast::{Data, Style};
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, Span};
6use quote::quote;
7use syn::{Error, Type, visit_mut::VisitMut};
8
9use crate::{
10    args::{
11        self, InterfaceField, InterfaceFieldArgument, RenameRuleExt, RenameTarget,
12        TypeDirectiveLocation,
13    },
14    output_type::OutputType,
15    utils::{
16        GeneratorResult, RemoveLifetime, gen_boxed_trait, gen_deprecation, gen_directive_calls,
17        generate_default, get_crate_name, get_rustdoc, visible_fn,
18    },
19};
20
21pub fn generate(interface_args: &args::Interface) -> GeneratorResult<TokenStream> {
22    let crate_name = get_crate_name(interface_args.internal);
23    let boxed_trait = gen_boxed_trait(&crate_name);
24    let ident = &interface_args.ident;
25    let type_params = interface_args.generics.type_params().collect::<Vec<_>>();
26    let (impl_generics, ty_generics, where_clause) = interface_args.generics.split_for_impl();
27    let s = match &interface_args.data {
28        Data::Enum(s) => s,
29        _ => {
30            return Err(
31                Error::new_spanned(ident, "Interface can only be applied to an enum.").into(),
32            );
33        }
34    };
35    let extends = interface_args.extends;
36    let mut enum_names = Vec::new();
37    let mut enum_items = HashSet::new();
38    let mut type_into_impls = Vec::new();
39    let inaccessible = interface_args.inaccessible;
40    let tags = interface_args
41        .tags
42        .iter()
43        .map(|tag| quote!(::std::string::ToString::to_string(#tag)))
44        .collect::<Vec<_>>();
45    let requires_scopes = interface_args
46        .requires_scopes
47        .iter()
48        .map(|scopes| quote!(::std::string::ToString::to_string(#scopes)))
49        .collect::<Vec<_>>();
50    let directives =
51        gen_directive_calls(&interface_args.directives, TypeDirectiveLocation::Interface);
52    let gql_typename = if !interface_args.name_type {
53        let name = interface_args
54            .name
55            .clone()
56            .unwrap_or_else(|| RenameTarget::Type.rename(ident.to_string()));
57        quote!(::std::borrow::Cow::Borrowed(#name))
58    } else {
59        quote!(<Self as #crate_name::TypeName>::type_name())
60    };
61
62    let desc = get_rustdoc(&interface_args.attrs)?
63        .map(|s| quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#s)) })
64        .unwrap_or_else(|| quote! {::std::option::Option::None});
65
66    let mut registry_types = Vec::new();
67    let mut possible_types = Vec::new();
68    let mut get_introspection_typename = Vec::new();
69    let mut collect_all_fields = Vec::new();
70
71    for variant in s {
72        let enum_name = &variant.ident;
73        let ty = match variant.fields.style {
74            Style::Tuple if variant.fields.fields.len() == 1 => &variant.fields.fields[0],
75            Style::Tuple => {
76                return Err(Error::new_spanned(
77                    enum_name,
78                    "Only single value variants are supported",
79                )
80                .into());
81            }
82            Style::Unit => {
83                return Err(
84                    Error::new_spanned(enum_name, "Empty variants are not supported").into(),
85                );
86            }
87            Style::Struct => {
88                return Err(Error::new_spanned(
89                    enum_name,
90                    "Variants with named fields are not supported",
91                )
92                .into());
93            }
94        };
95
96        if let Type::Path(p) = ty {
97            // This validates that the field type wasn't already used
98            if !enum_items.insert(p) {
99                return Err(
100                    Error::new_spanned(ty, "This type already used in another variant").into(),
101                );
102            }
103
104            let mut assert_ty = p.clone();
105            RemoveLifetime.visit_type_path_mut(&mut assert_ty);
106
107            type_into_impls.push(quote! {
108                #crate_name::static_assertions_next::assert_impl!(for(#(#type_params),*) #assert_ty: (#crate_name::ObjectType) | (#crate_name::InterfaceType));
109
110                #[allow(clippy::all, clippy::pedantic)]
111                impl #impl_generics ::std::convert::From<#p> for #ident #ty_generics #where_clause {
112                    fn from(obj: #p) -> Self {
113                        #ident::#enum_name(obj)
114                    }
115                }
116            });
117            enum_names.push(enum_name);
118
119            registry_types.push(quote! {
120                <#p as #crate_name::OutputType>::create_type_info(registry);
121                registry.add_implements(&<#p as #crate_name::OutputType>::type_name(), ::std::convert::AsRef::as_ref(&#gql_typename));
122            });
123
124            possible_types.push(quote! {
125                possible_types.insert(<#p as #crate_name::OutputType>::type_name().into_owned());
126            });
127
128            get_introspection_typename.push(quote! {
129                #ident::#enum_name(obj) => <#p as #crate_name::OutputType>::type_name()
130            });
131
132            collect_all_fields.push(quote! {
133                #ident::#enum_name(obj) => obj.collect_all_fields(ctx, fields)
134            });
135        } else {
136            return Err(Error::new_spanned(ty, "Invalid type").into());
137        }
138    }
139
140    let mut methods = Vec::new();
141    let mut schema_fields = Vec::new();
142    let mut resolvers = Vec::new();
143
144    if interface_args.fields.is_empty() {
145        return Err(Error::new_spanned(
146            ident,
147            "A GraphQL Interface type must define one or more fields.",
148        )
149        .into());
150    }
151
152    for InterfaceField {
153        name,
154        method,
155        desc,
156        ty,
157        args,
158        deprecation,
159        external,
160        provides,
161        requires,
162        visible,
163        shareable,
164        inaccessible,
165        tags,
166        override_from,
167        directives,
168        requires_scopes,
169    } in &interface_args.fields
170    {
171        let (name, method_name) = if let Some(method) = method {
172            (name.to_string(), Ident::new_raw(method, Span::call_site()))
173        } else {
174            let method_name = Ident::new_raw(name, Span::call_site());
175            (
176                interface_args
177                    .rename_fields
178                    .rename(name.as_ref(), RenameTarget::Field),
179                method_name,
180            )
181        };
182        let mut calls = Vec::new();
183        let mut use_params = Vec::new();
184        let mut decl_params = Vec::new();
185        let mut get_params = Vec::new();
186        let mut schema_args = Vec::new();
187        let requires = match &requires {
188            Some(requires) => {
189                quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#requires)) }
190            }
191            None => quote! { ::std::option::Option::None },
192        };
193        let provides = match &provides {
194            Some(provides) => {
195                quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#provides)) }
196            }
197            None => quote! { ::std::option::Option::None },
198        };
199        let override_from = match &override_from {
200            Some(from) => {
201                quote! { ::std::option::Option::Some(::std::string::ToString::to_string(#from)) }
202            }
203            None => quote! { ::std::option::Option::None },
204        };
205
206        decl_params.push(quote! { ctx: &'ctx #crate_name::Context<'ctx> });
207        use_params.push(quote! { ctx });
208
209        for (
210            i,
211            InterfaceFieldArgument {
212                name,
213                desc,
214                ty,
215                default,
216                default_with,
217                visible,
218                inaccessible,
219                tags,
220                secret,
221                directives,
222                deprecation,
223            },
224        ) in args.iter().enumerate()
225        {
226            let ident = Ident::new(&format!("arg{}", i), Span::call_site());
227            let name = interface_args
228                .rename_args
229                .rename(name, RenameTarget::Argument);
230            decl_params.push(quote! { #ident: #ty });
231            use_params.push(quote! { #ident });
232
233            let default = generate_default(default, default_with)?;
234            let get_default = match &default {
235                Some(default) => quote! { ::std::option::Option::Some(|| -> #ty { #default }) },
236                None => quote! { ::std::option::Option::None },
237            };
238            get_params.push(quote! {
239                let (_, #ident) = ctx.param_value::<#ty>(#name, #get_default)?;
240            });
241
242            let desc = desc
243                .as_ref()
244                .map(|s| quote! {::std::option::Option::Some(::std::string::ToString::to_string(#s))})
245                .unwrap_or_else(|| quote! {::std::option::Option::None});
246            let schema_default = default
247                .as_ref()
248                .map(|value| {
249                    quote! {
250                        ::std::option::Option::Some(::std::string::ToString::to_string(
251                            &<#ty as #crate_name::InputType>::to_value(&#value)
252                        ))
253                    }
254                })
255                .unwrap_or_else(|| quote! {::std::option::Option::None});
256            let visible = visible_fn(visible);
257            let tags = tags
258                .iter()
259                .map(|tag| quote!(::std::string::ToString::to_string(#tag)))
260                .collect::<Vec<_>>();
261            let directives =
262                gen_directive_calls(directives, TypeDirectiveLocation::ArgumentDefinition);
263            let deprecation = gen_deprecation(deprecation, &crate_name);
264
265            schema_args.push(quote! {
266                    args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue {
267                        name: ::std::string::ToString::to_string(#name),
268                        description: #desc,
269                        ty: <#ty as #crate_name::InputType>::create_type_info(registry),
270                        deprecation: #deprecation,
271                        default_value: #schema_default,
272                        visible: #visible,
273                        inaccessible: #inaccessible,
274                        tags: ::std::vec![ #(#tags),* ],
275                        is_secret: #secret,
276                        directive_invocations: ::std::vec![ #(#directives),* ],
277                    });
278                });
279        }
280
281        for enum_name in &enum_names {
282            calls.push(quote! {
283                #ident::#enum_name(obj) => obj.#method_name(#(#use_params),*)
284                    .await.map_err(|err| ::std::convert::Into::<#crate_name::Error>::into(err))
285                    .map(::std::convert::Into::into)
286            });
287        }
288
289        let desc = desc
290            .as_ref()
291            .map(|s| quote! {::std::option::Option::Some(::std::string::ToString::to_string(#s))})
292            .unwrap_or_else(|| quote! {::std::option::Option::None});
293        let deprecation = gen_deprecation(deprecation, &crate_name);
294
295        let oty = OutputType::parse(ty)?;
296        let ty = match oty {
297            OutputType::Value(ty) => ty,
298            OutputType::Result(ty) => ty,
299        };
300        let schema_ty = oty.value_type();
301
302        methods.push(quote! {
303            #[inline]
304            pub async fn #method_name<'ctx>(&self, #(#decl_params),*) -> #crate_name::Result<#ty> {
305                match self {
306                    #(#calls,)*
307                }
308            }
309        });
310
311        let visible = visible_fn(visible);
312        let tags = tags
313            .iter()
314            .map(|tag| quote!(::std::string::ToString::to_string(#tag)))
315            .collect::<Vec<_>>();
316        let requires_scopes = requires_scopes
317            .iter()
318            .map(|scopes| quote!(::std::string::ToString::to_string(#scopes)))
319            .collect::<Vec<_>>();
320        let directives = gen_directive_calls(directives, TypeDirectiveLocation::FieldDefinition);
321
322        schema_fields.push(quote! {
323            fields.insert(::std::string::ToString::to_string(#name), #crate_name::registry::MetaField {
324                name: ::std::string::ToString::to_string(#name),
325                description: #desc,
326                args: {
327                    let mut args = #crate_name::indexmap::IndexMap::new();
328                    #(#schema_args)*
329                    args
330                },
331                ty: <#schema_ty as #crate_name::OutputType>::create_type_info(registry),
332                deprecation: #deprecation,
333                cache_control: ::std::default::Default::default(),
334                external: #external,
335                provides: #provides,
336                requires: #requires,
337                shareable: #shareable,
338                inaccessible: #inaccessible,
339                tags: ::std::vec![ #(#tags),* ],
340                override_from: #override_from,
341                visible: #visible,
342                compute_complexity: ::std::option::Option::None,
343                directive_invocations: ::std::vec![ #(#directives),* ],
344                requires_scopes: ::std::vec![ #(#requires_scopes),* ],
345            });
346        });
347
348        let resolve_obj = quote! {
349            self.#method_name(#(#use_params),*)
350                .await
351                .map_err(|err| ::std::convert::Into::<#crate_name::Error>::into(err).into_server_error(ctx.item.pos))?
352        };
353
354        resolvers.push(quote! {
355            if ctx.item.node.name.node == #name {
356                #(#get_params)*
357                let ctx_obj = ctx.with_selection_set(&ctx.item.node.selection_set);
358                return #crate_name::OutputType::resolve(&#resolve_obj, &ctx_obj, ctx.item).await.map(::std::option::Option::Some);
359            }
360        });
361    }
362
363    let introspection_type_name = if get_introspection_typename.is_empty() {
364        quote! { ::std::unreachable!() }
365    } else {
366        quote! {
367            match self {
368            #(#get_introspection_typename),*
369            }
370        }
371    };
372
373    let visible = visible_fn(&interface_args.visible);
374    let expanded = quote! {
375        #(#type_into_impls)*
376
377        #[allow(clippy::all, clippy::pedantic)]
378        impl #impl_generics #ident #ty_generics #where_clause {
379            #(#methods)*
380        }
381
382        #[allow(clippy::all, clippy::pedantic)]
383        #boxed_trait
384        impl #impl_generics #crate_name::resolver_utils::ContainerType for #ident #ty_generics #where_clause {
385            async fn resolve_field(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::ServerResult<::std::option::Option<#crate_name::Value>> {
386                #(#resolvers)*
387                ::std::result::Result::Ok(::std::option::Option::None)
388            }
389
390            fn collect_all_fields<'__life>(&'__life self, ctx: &#crate_name::ContextSelectionSet<'__life>, fields: &mut #crate_name::resolver_utils::Fields<'__life>) -> #crate_name::ServerResult<()> {
391                match self {
392                    #(#collect_all_fields),*
393                }
394            }
395        }
396
397        #[allow(clippy::all, clippy::pedantic)]
398        #boxed_trait
399        impl #impl_generics #crate_name::OutputType for #ident #ty_generics #where_clause {
400            fn type_name() -> ::std::borrow::Cow<'static, ::std::primitive::str> {
401                #gql_typename
402            }
403
404            fn introspection_type_name(&self) -> ::std::borrow::Cow<'static, ::std::primitive::str> {
405                #introspection_type_name
406            }
407
408            fn create_type_info(registry: &mut #crate_name::registry::Registry) -> ::std::string::String {
409                registry.create_output_type::<Self, _>(#crate_name::registry::MetaTypeId::Interface, |registry| {
410                    #(#registry_types)*
411
412                    #crate_name::registry::MetaType::Interface {
413                        name: ::std::borrow::Cow::into_owned(#gql_typename),
414                        description: #desc,
415                        fields: {
416                            let mut fields = #crate_name::indexmap::IndexMap::new();
417                            #(#schema_fields)*
418                            fields
419                        },
420                        possible_types: {
421                            let mut possible_types = #crate_name::indexmap::IndexSet::new();
422                            #(#possible_types)*
423                            possible_types
424                        },
425                        extends: #extends,
426                        keys: ::std::option::Option::None,
427                        visible: #visible,
428                        inaccessible: #inaccessible,
429                        tags: ::std::vec![ #(#tags),* ],
430                        rust_typename: ::std::option::Option::Some(::std::any::type_name::<Self>()),
431                        directive_invocations: ::std::vec![ #(#directives),* ],
432                        requires_scopes: ::std::vec![ #(#requires_scopes),* ],
433                    }
434                })
435            }
436
437            async fn resolve(
438                &self,
439                ctx: &#crate_name::ContextSelectionSet<'_>,
440                _field: &#crate_name::Positioned<#crate_name::parser::types::Field>,
441            ) -> #crate_name::ServerResult<#crate_name::Value> {
442                #crate_name::resolver_utils::resolve_container(ctx, self).await
443            }
444        }
445
446        impl #impl_generics #crate_name::InterfaceType for #ident #ty_generics #where_clause {}
447    };
448    Ok(expanded.into())
449}