Skip to main content

async_graphql/
schema.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet},
4    ops::Deref,
5    sync::Arc,
6};
7
8use async_graphql_parser::types::ExecutableDocument;
9use futures_util::stream::{self, BoxStream, FuturesOrdered, StreamExt};
10
11use crate::{
12    BatchRequest, BatchResponse, CacheControl, ContextBase, EmptyMutation, EmptySubscription,
13    Executor, InputType, ObjectType, OutputType, QueryEnv, Request, Response, ServerError,
14    ServerResult, SubscriptionType, Variables,
15    context::{Data, QueryEnvInner},
16    custom_directive::CustomDirectiveFactory,
17    extensions::{ExtensionFactory, Extensions},
18    parser::{
19        Positioned, parse_query,
20        types::{Directive, DocumentOperations, OperationType, Selection, SelectionSet},
21    },
22    registry::{Registry, SDLExportOptions},
23    resolver_utils::{resolve_container, resolve_container_serial},
24    subscription::collect_subscription_streams,
25    types::QueryRoot,
26    validation::{ValidationMode, check_rules},
27};
28
/// Controls how a schema treats introspection queries.
///
/// Defaults to [`IntrospectionMode::Enabled`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IntrospectionMode {
    /// Only introspection queries are processed; everything else is
    /// processed as an error.
    IntrospectionOnly,
    /// Introspection queries are allowed alongside regular operations.
    #[default]
    Enabled,
    /// Introspection queries are not allowed.
    Disabled,
}
40
41/// Schema builder
42pub struct SchemaBuilder<Query, Mutation, Subscription> {
43    validation_mode: ValidationMode,
44    query: QueryRoot<Query>,
45    mutation: Mutation,
46    subscription: Subscription,
47    registry: Registry,
48    data: Data,
49    complexity: Option<usize>,
50    depth: Option<usize>,
51    recursive_depth: usize,
52    max_directives: Option<usize>,
53    max_aliases: Option<usize>,
54    extensions: Vec<Box<dyn ExtensionFactory>>,
55    custom_directives: HashMap<String, Box<dyn CustomDirectiveFactory>>,
56}
57
58impl<Query, Mutation, Subscription> SchemaBuilder<Query, Mutation, Subscription> {
59    /// Manually register a input type in the schema.
60    ///
61    /// You can use this function to register schema types that are not directly
62    /// referenced.
63    #[must_use]
64    pub fn register_input_type<T: InputType>(mut self) -> Self {
65        T::create_type_info(&mut self.registry);
66        self
67    }
68
69    /// Manually register a output type in the schema.
70    ///
71    /// You can use this function to register schema types that are not directly
72    /// referenced.
73    #[must_use]
74    pub fn register_output_type<T: OutputType>(mut self) -> Self {
75        T::create_type_info(&mut self.registry);
76        self
77    }
78
79    /// Disable introspection queries.
80    #[must_use]
81    pub fn disable_introspection(mut self) -> Self {
82        self.registry.introspection_mode = IntrospectionMode::Disabled;
83        self
84    }
85
86    /// Only process introspection queries, everything else is processed as an
87    /// error.
88    #[must_use]
89    pub fn introspection_only(mut self) -> Self {
90        self.registry.introspection_mode = IntrospectionMode::IntrospectionOnly;
91        self
92    }
93
94    /// Set the maximum complexity a query can have. By default, there is no
95    /// limit.
96    #[must_use]
97    pub fn limit_complexity(mut self, complexity: usize) -> Self {
98        self.complexity = Some(complexity);
99        self
100    }
101
102    /// Set the maximum amount of aliases a query can have. By default, there is
103    /// no limit.
104    #[must_use]
105    pub fn limit_aliases(mut self, max_aliases: usize) -> Self {
106        self.max_aliases = Some(max_aliases);
107        self
108    }
109
110    /// Set the maximum depth a query can have. By default, there is no limit.
111    #[must_use]
112    pub fn limit_depth(mut self, depth: usize) -> Self {
113        self.depth = Some(depth);
114        self
115    }
116
117    /// Set the maximum recursive depth a query can have. (default: 32)
118    ///
119    /// If the value is too large, stack overflow may occur, usually `32` is
120    /// enough.
121    #[must_use]
122    pub fn limit_recursive_depth(mut self, depth: usize) -> Self {
123        self.recursive_depth = depth;
124        self
125    }
126
127    /// Set the maximum number of directives on a single field. (default: no
128    /// limit)
129    pub fn limit_directives(mut self, max_directives: usize) -> Self {
130        self.max_directives = Some(max_directives);
131        self
132    }
133
134    /// Add an extension to the schema.
135    ///
136    /// # Examples
137    ///
138    /// ```rust
139    /// use async_graphql::*;
140    ///
141    /// struct Query;
142    ///
143    /// #[Object]
144    /// impl Query {
145    ///     async fn value(&self) -> i32 {
146    ///         100
147    ///     }
148    /// }
149    ///
150    /// let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
151    ///     .extension(extensions::Logger)
152    ///     .finish();
153    /// ```
154    #[must_use]
155    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
156        self.extensions.push(Box::new(extension));
157        self
158    }
159
160    /// Add a global data that can be accessed in the `Schema`. You access it
161    /// with `Context::data`.
162    #[must_use]
163    pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
164        self.data.insert(data);
165        self
166    }
167
168    /// Set the validation mode, default is `ValidationMode::Strict`.
169    #[must_use]
170    pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self {
171        self.validation_mode = validation_mode;
172        self
173    }
174
175    /// Enable federation, which is automatically enabled if the Query has least
176    /// one entity definition.
177    #[must_use]
178    pub fn enable_federation(mut self) -> Self {
179        self.registry.enable_federation = true;
180        self
181    }
182
183    /// Make the Federation SDL include subscriptions.
184    ///
185    /// Note: Not included by default, in order to be compatible with Apollo
186    /// Server.
187    #[must_use]
188    pub fn enable_subscription_in_federation(mut self) -> Self {
189        self.registry.federation_subscription = true;
190        self
191    }
192
193    /// Override the name of the specified input type.
194    #[must_use]
195    pub fn override_input_type_description<T: InputType>(mut self, desc: &'static str) -> Self {
196        self.registry.set_description(&*T::type_name(), desc);
197        self
198    }
199
200    /// Override the name of the specified output type.
201    #[must_use]
202    pub fn override_output_type_description<T: OutputType>(mut self, desc: &'static str) -> Self {
203        self.registry.set_description(&*T::type_name(), desc);
204        self
205    }
206
207    /// Register a custom directive.
208    ///
209    /// # Panics
210    ///
211    /// Panics if the directive with the same name is already registered.
212    #[must_use]
213    pub fn directive<T: CustomDirectiveFactory>(mut self, directive: T) -> Self {
214        let name = directive.name();
215        let instance = Box::new(directive);
216
217        instance.register(&mut self.registry);
218
219        if name == "skip"
220            || name == "include"
221            || self
222                .custom_directives
223                .insert(name.clone().into(), instance)
224                .is_some()
225        {
226            panic!("Directive `{}` already exists", name);
227        }
228
229        self
230    }
231
232    /// Disable field suggestions.
233    #[must_use]
234    pub fn disable_suggestions(mut self) -> Self {
235        self.registry.enable_suggestions = false;
236        self
237    }
238
239    /// Make all fields sorted on introspection queries.
240    pub fn with_sorted_fields(mut self) -> Self {
241        use crate::registry::MetaType;
242        for ty in self.registry.types.values_mut() {
243            match ty {
244                MetaType::Object { fields, .. } | MetaType::Interface { fields, .. } => {
245                    fields.sort_keys();
246                }
247                MetaType::InputObject { input_fields, .. } => {
248                    input_fields.sort_keys();
249                }
250                MetaType::Scalar { .. } | MetaType::Enum { .. } | MetaType::Union { .. } => {
251                    // have no fields
252                }
253            }
254        }
255        self
256    }
257
258    /// Make all enum variants sorted on introspection queries.
259    pub fn with_sorted_enums(mut self) -> Self {
260        use crate::registry::MetaType;
261        for ty in &mut self.registry.types.values_mut() {
262            if let MetaType::Enum { enum_values, .. } = ty {
263                enum_values.sort_keys();
264            }
265        }
266        self
267    }
268
269    /// Consumes this builder and returns a schema.
270    pub fn finish(mut self) -> Schema<Query, Mutation, Subscription> {
271        // federation
272        if self.registry.enable_federation || self.registry.has_entities() {
273            self.registry.create_federation_types();
274        }
275
276        Schema(Arc::new(SchemaInner {
277            validation_mode: self.validation_mode,
278            query: self.query,
279            mutation: self.mutation,
280            subscription: self.subscription,
281            complexity: self.complexity,
282            depth: self.depth,
283            recursive_depth: self.recursive_depth,
284            max_directives: self.max_directives,
285            max_aliases: self.max_aliases,
286            extensions: self.extensions,
287            env: SchemaEnv(Arc::new(SchemaEnvInner {
288                registry: self.registry,
289                data: self.data,
290                custom_directives: self.custom_directives,
291            })),
292        }))
293    }
294}
295
296#[doc(hidden)]
297pub struct SchemaEnvInner {
298    pub registry: Registry,
299    pub data: Data,
300    pub custom_directives: HashMap<String, Box<dyn CustomDirectiveFactory>>,
301}
302
303#[doc(hidden)]
304#[derive(Clone)]
305pub struct SchemaEnv(pub(crate) Arc<SchemaEnvInner>);
306
307impl Deref for SchemaEnv {
308    type Target = SchemaEnvInner;
309
310    fn deref(&self) -> &Self::Target {
311        &self.0
312    }
313}
314
315#[doc(hidden)]
316pub struct SchemaInner<Query, Mutation, Subscription> {
317    pub(crate) validation_mode: ValidationMode,
318    pub(crate) query: QueryRoot<Query>,
319    pub(crate) mutation: Mutation,
320    pub(crate) subscription: Subscription,
321    pub(crate) complexity: Option<usize>,
322    pub(crate) depth: Option<usize>,
323    pub(crate) recursive_depth: usize,
324    pub(crate) max_directives: Option<usize>,
325    pub(crate) max_aliases: Option<usize>,
326    pub(crate) extensions: Vec<Box<dyn ExtensionFactory>>,
327    pub(crate) env: SchemaEnv,
328}
329
330/// GraphQL schema.
331///
332/// Cloning a schema is cheap, so it can be easily shared.
333pub struct Schema<Query, Mutation, Subscription>(
334    pub(crate) Arc<SchemaInner<Query, Mutation, Subscription>>,
335);
336
337impl<Query, Mutation, Subscription> Clone for Schema<Query, Mutation, Subscription> {
338    fn clone(&self) -> Self {
339        Schema(self.0.clone())
340    }
341}
342
343impl<Query, Mutation, Subscription> Default for Schema<Query, Mutation, Subscription>
344where
345    Query: Default + ObjectType + 'static,
346    Mutation: Default + ObjectType + 'static,
347    Subscription: Default + SubscriptionType + 'static,
348{
349    fn default() -> Self {
350        Schema::new(
351            Query::default(),
352            Mutation::default(),
353            Subscription::default(),
354        )
355    }
356}
357
358impl<Query, Mutation, Subscription> Schema<Query, Mutation, Subscription>
359where
360    Query: ObjectType + 'static,
361    Mutation: ObjectType + 'static,
362    Subscription: SubscriptionType + 'static,
363{
364    /// Create a schema builder
365    ///
366    /// The root object for the query and Mutation needs to be specified.
367    /// If there is no mutation, you can use `EmptyMutation`.
368    /// If there is no subscription, you can use `EmptySubscription`.
369    pub fn build(
370        query: Query,
371        mutation: Mutation,
372        subscription: Subscription,
373    ) -> SchemaBuilder<Query, Mutation, Subscription> {
374        Self::build_with_ignore_name_conflicts(query, mutation, subscription, [] as [&str; 0])
375    }
376
377    /// Create a schema builder and specifies a list to ignore type conflict
378    /// detection.
379    ///
380    /// NOTE: It is not recommended to use it unless you know what it does.
381    #[must_use]
382    pub fn build_with_ignore_name_conflicts<I, T>(
383        query: Query,
384        mutation: Mutation,
385        subscription: Subscription,
386        ignore_name_conflicts: I,
387    ) -> SchemaBuilder<Query, Mutation, Subscription>
388    where
389        I: IntoIterator<Item = T>,
390        T: Into<String>,
391    {
392        SchemaBuilder {
393            validation_mode: ValidationMode::Strict,
394            query: QueryRoot { inner: query },
395            mutation,
396            subscription,
397            registry: Self::create_registry(
398                ignore_name_conflicts.into_iter().map(Into::into).collect(),
399            ),
400            data: Default::default(),
401            complexity: None,
402            depth: None,
403            recursive_depth: 32,
404            max_directives: None,
405            max_aliases: None,
406            extensions: Default::default(),
407            custom_directives: Default::default(),
408        }
409    }
410
411    pub(crate) fn create_registry(ignore_name_conflicts: HashSet<String>) -> Registry {
412        let mut registry = Registry {
413            types: Default::default(),
414            directives: Default::default(),
415            implements: Default::default(),
416            query_type: Query::type_name().to_string(),
417            mutation_type: if Mutation::is_empty() {
418                None
419            } else {
420                Some(Mutation::type_name().to_string())
421            },
422            subscription_type: if Subscription::is_empty() {
423                None
424            } else {
425                Some(Subscription::type_name().to_string())
426            },
427            introspection_mode: IntrospectionMode::Enabled,
428            enable_federation: false,
429            federation_subscription: false,
430            ignore_name_conflicts,
431            enable_suggestions: true,
432        };
433        registry.add_system_types();
434
435        QueryRoot::<Query>::create_type_info(&mut registry);
436        if !Mutation::is_empty() {
437            Mutation::create_type_info(&mut registry);
438        }
439        if !Subscription::is_empty() {
440            Subscription::create_type_info(&mut registry);
441        }
442
443        registry.remove_unused_types();
444        registry
445    }
446
447    /// Create a schema
448    pub fn new(
449        query: Query,
450        mutation: Mutation,
451        subscription: Subscription,
452    ) -> Schema<Query, Mutation, Subscription> {
453        Self::build(query, mutation, subscription).finish()
454    }
455
456    #[inline]
457    #[allow(unused)]
458    pub(crate) fn registry(&self) -> &Registry {
459        &self.0.env.registry
460    }
461
462    /// Returns SDL(Schema Definition Language) of this schema.
463    pub fn sdl(&self) -> String {
464        self.0.env.registry.export_sdl(Default::default())
465    }
466
467    /// Returns SDL(Schema Definition Language) of this schema with options.
468    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
469        self.0.env.registry.export_sdl(options)
470    }
471
472    /// Get all names in this schema
473    ///
474    /// Maybe you want to serialize a custom binary protocol. In order to
475    /// minimize message size, a dictionary is usually used to compress type
476    /// names, field names, directive names, and parameter names. This function
477    /// gets all the names, so you can create this dictionary.
478    pub fn names(&self) -> Vec<String> {
479        self.0.env.registry.names()
480    }
481
482    fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
483        Extensions::new(
484            self.0.extensions.iter().map(|f| f.create()),
485            self.0.env.clone(),
486            session_data,
487        )
488    }
489
490    async fn execute_once(&self, env: QueryEnv, execute_data: Option<&Data>) -> Response {
491        // execute
492        let ctx = ContextBase {
493            path_node: None,
494            is_for_introspection: false,
495            item: &env.operation.node.selection_set,
496            schema_env: &self.0.env,
497            query_env: &env,
498            execute_data,
499        };
500
501        let res = match &env.operation.node.ty {
502            OperationType::Query => resolve_container(&ctx, &self.0.query).await,
503            OperationType::Mutation => {
504                if self.0.env.registry.introspection_mode == IntrospectionMode::IntrospectionOnly
505                    || env.introspection_mode == IntrospectionMode::IntrospectionOnly
506                {
507                    resolve_container_serial(&ctx, &EmptyMutation).await
508                } else {
509                    resolve_container_serial(&ctx, &self.0.mutation).await
510                }
511            }
512            OperationType::Subscription => Err(ServerError::new(
513                "Subscriptions are not supported on this transport.",
514                None,
515            )),
516        };
517
518        let mut resp = match res {
519            Ok(value) => Response::new(value),
520            Err(err) => Response::from_errors(vec![err]),
521        }
522        .http_headers(std::mem::take(&mut *env.http_headers.lock().unwrap()));
523
524        resp.errors
525            .extend(std::mem::take(&mut *env.errors.lock().unwrap()));
526        resp
527    }
528
529    /// Execute a GraphQL query.
530    pub async fn execute(&self, request: impl Into<Request>) -> Response {
531        let request = request.into();
532        let extensions = self.create_extensions(Default::default());
533        let request_fut = {
534            let extensions = extensions.clone();
535            async move {
536                match prepare_request(
537                    extensions,
538                    request,
539                    Default::default(),
540                    &self.0.env.registry,
541                    self.0.validation_mode,
542                    self.0.recursive_depth,
543                    self.0.max_directives,
544                    self.0.max_aliases,
545                    self.0.complexity,
546                    self.0.depth,
547                )
548                .await
549                {
550                    Ok((env, cache_control)) => {
551                        let f = |execute_data: Option<Data>| {
552                            let env = env.clone();
553                            async move {
554                                self.execute_once(env, execute_data.as_ref())
555                                    .await
556                                    .cache_control(cache_control)
557                            }
558                        };
559                        env.extensions
560                            .execute(env.operation_name.as_deref(), f)
561                            .await
562                    }
563                    Err(errors) => Response::from_errors(errors),
564                }
565            }
566        };
567        futures_util::pin_mut!(request_fut);
568        extensions.request(&mut request_fut).await
569    }
570
571    /// Execute a GraphQL batch query.
572    pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse {
573        match batch_request {
574            BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await),
575            BatchRequest::Batch(requests) => BatchResponse::Batch(
576                FuturesOrdered::from_iter(
577                    requests.into_iter().map(|request| self.execute(request)),
578                )
579                .collect()
580                .await,
581            ),
582        }
583    }
584
585    /// Execute a GraphQL subscription with session data.
586    pub fn execute_stream_with_session_data(
587        &self,
588        request: impl Into<Request>,
589        session_data: Arc<Data>,
590    ) -> BoxStream<'static, Response> {
591        let schema = self.clone();
592        let request = request.into();
593        let extensions = self.create_extensions(session_data.clone());
594
595        let stream = futures_util::stream::StreamExt::boxed({
596            let extensions = extensions.clone();
597            let env = self.0.env.clone();
598            asynk_strim::stream_fn(|mut yielder| async move {
599                let (env, cache_control) = match prepare_request(
600                    extensions,
601                    request,
602                    session_data,
603                    &env.registry,
604                    schema.0.validation_mode,
605                    schema.0.recursive_depth,
606                    schema.0.max_directives,
607                    schema.0.max_aliases,
608                    schema.0.complexity,
609                    schema.0.depth,
610                )
611                .await
612                {
613                    Ok(res) => res,
614                    Err(errors) => {
615                        yielder.yield_item(Response::from_errors(errors)).await;
616                        return;
617                    }
618                };
619
620                if env.operation.node.ty != OperationType::Subscription {
621                    let f = |execute_data: Option<Data>| {
622                        let env = env.clone();
623                        let schema = schema.clone();
624                        async move {
625                            schema
626                                .execute_once(env, execute_data.as_ref())
627                                .await
628                                .cache_control(cache_control)
629                        }
630                    };
631                    yielder
632                        .yield_item(
633                            env.extensions
634                                .execute(env.operation_name.as_deref(), f)
635                                .await
636                                .cache_control(cache_control),
637                        )
638                        .await;
639                    return;
640                }
641
642                let ctx = env.create_context(
643                    &schema.0.env,
644                    None,
645                    &env.operation.node.selection_set,
646                    None,
647                );
648
649                let mut streams = Vec::new();
650                let collect_result = if schema.0.env.registry.introspection_mode
651                    == IntrospectionMode::IntrospectionOnly
652                    || env.introspection_mode == IntrospectionMode::IntrospectionOnly
653                {
654                    collect_subscription_streams(&ctx, &EmptySubscription, &mut streams)
655                } else {
656                    collect_subscription_streams(&ctx, &schema.0.subscription, &mut streams)
657                };
658                if let Err(err) = collect_result {
659                    yielder.yield_item(Response::from_errors(vec![err])).await;
660                }
661
662                let mut stream = stream::select_all(streams);
663                while let Some(resp) = stream.next().await {
664                    yielder.yield_item(resp).await;
665                }
666            })
667        });
668        extensions.subscribe(stream)
669    }
670
671    /// Execute a GraphQL subscription.
672    pub fn execute_stream(&self, request: impl Into<Request>) -> BoxStream<'static, Response> {
673        self.execute_stream_with_session_data(request, Default::default())
674    }
675
676    /// Access global data stored in the Schema
677    pub fn data<D: Any + Send + Sync>(&self) -> Option<&D> {
678        self.0
679            .env
680            .data
681            .get(&TypeId::of::<D>())
682            .and_then(|d| d.downcast_ref::<D>())
683    }
684}
685
686#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
687impl<Query, Mutation, Subscription> Executor for Schema<Query, Mutation, Subscription>
688where
689    Query: ObjectType + 'static,
690    Mutation: ObjectType + 'static,
691    Subscription: SubscriptionType + 'static,
692{
693    async fn execute(&self, request: Request) -> Response {
694        Schema::execute(self, request).await
695    }
696
697    fn execute_stream(
698        &self,
699        request: Request,
700        session_data: Option<Arc<Data>>,
701    ) -> BoxStream<'static, Response> {
702        Schema::execute_stream_with_session_data(&self, request, session_data.unwrap_or_default())
703    }
704}
705
706fn check_max_directives(doc: &ExecutableDocument, max_directives: usize) -> ServerResult<()> {
707    fn check_selection_set(
708        doc: &ExecutableDocument,
709        selection_set: &Positioned<SelectionSet>,
710        limit_directives: usize,
711    ) -> ServerResult<()> {
712        for selection in &selection_set.node.items {
713            match &selection.node {
714                Selection::Field(field) => {
715                    if field.node.directives.len() > limit_directives {
716                        return Err(ServerError::new(
717                            format!(
718                                "The number of directives on the field `{}` cannot be greater than `{}`",
719                                field.node.name.node, limit_directives
720                            ),
721                            Some(field.pos),
722                        ));
723                    }
724                    check_selection_set(doc, &field.node.selection_set, limit_directives)?;
725                }
726                Selection::FragmentSpread(fragment_spread) => {
727                    if let Some(fragment) =
728                        doc.fragments.get(&fragment_spread.node.fragment_name.node)
729                    {
730                        check_selection_set(doc, &fragment.node.selection_set, limit_directives)?;
731                    }
732                }
733                Selection::InlineFragment(inline_fragment) => {
734                    check_selection_set(
735                        doc,
736                        &inline_fragment.node.selection_set,
737                        limit_directives,
738                    )?;
739                }
740            }
741        }
742
743        Ok(())
744    }
745
746    for (_, operation) in doc.operations.iter() {
747        check_selection_set(doc, &operation.node.selection_set, max_directives)?;
748    }
749
750    Ok(())
751}
752
753fn check_alias_count(doc: &ExecutableDocument, max_aliases: usize) -> ServerResult<()> {
754    fn check_selection_set(
755        doc: &ExecutableDocument,
756        selection_set: &Positioned<SelectionSet>,
757        current_aliases: &mut usize,
758        max_aliases: usize,
759    ) -> ServerResult<()> {
760        for selection in &selection_set.node.items {
761            match &selection.node {
762                Selection::Field(field) => {
763                    if field.node.alias.is_some() {
764                        *current_aliases += 1;
765                    }
766                    if !field.node.selection_set.node.items.is_empty() {
767                        check_selection_set(
768                            doc,
769                            &field.node.selection_set,
770                            current_aliases,
771                            max_aliases,
772                        )?;
773                    }
774                }
775                Selection::FragmentSpread(fragment_spread) => {
776                    if let Some(fragment) =
777                        doc.fragments.get(&fragment_spread.node.fragment_name.node)
778                    {
779                        check_selection_set(
780                            doc,
781                            &fragment.node.selection_set,
782                            current_aliases,
783                            max_aliases,
784                        )?;
785                    }
786                }
787                Selection::InlineFragment(inline_fragment) => {
788                    check_selection_set(
789                        doc,
790                        &inline_fragment.node.selection_set,
791                        current_aliases,
792                        max_aliases,
793                    )?;
794                }
795            }
796
797            if *current_aliases > max_aliases {
798                return Err(ServerError::new(
799                    format!(
800                        "The amount of aliases of the query cannot be greater than `{}`",
801                        max_aliases
802                    ),
803                    Some(selection_set.pos),
804                ));
805            }
806        }
807
808        Ok(())
809    }
810
811    let mut current_aliases = 0;
812    for (_, operation) in doc.operations.iter() {
813        check_selection_set(
814            doc,
815            &operation.node.selection_set,
816            &mut current_aliases,
817            max_aliases,
818        )?;
819    }
820
821    Ok(())
822}
823
824fn check_recursive_depth(doc: &ExecutableDocument, max_depth: usize) -> ServerResult<()> {
825    fn check_selection_set(
826        doc: &ExecutableDocument,
827        selection_set: &Positioned<SelectionSet>,
828        current_depth: usize,
829        max_depth: usize,
830    ) -> ServerResult<()> {
831        if current_depth > max_depth {
832            return Err(ServerError::new(
833                format!(
834                    "The recursion depth of the query cannot be greater than `{}`",
835                    max_depth
836                ),
837                Some(selection_set.pos),
838            ));
839        }
840
841        for selection in &selection_set.node.items {
842            match &selection.node {
843                Selection::Field(field) => {
844                    if !field.node.selection_set.node.items.is_empty() {
845                        check_selection_set(
846                            doc,
847                            &field.node.selection_set,
848                            current_depth + 1,
849                            max_depth,
850                        )?;
851                    }
852                }
853                Selection::FragmentSpread(fragment_spread) => {
854                    if let Some(fragment) =
855                        doc.fragments.get(&fragment_spread.node.fragment_name.node)
856                    {
857                        check_selection_set(
858                            doc,
859                            &fragment.node.selection_set,
860                            current_depth + 1,
861                            max_depth,
862                        )?;
863                    }
864                }
865                Selection::InlineFragment(inline_fragment) => {
866                    check_selection_set(
867                        doc,
868                        &inline_fragment.node.selection_set,
869                        current_depth + 1,
870                        max_depth,
871                    )?;
872                }
873            }
874        }
875
876        Ok(())
877    }
878
879    for (_, operation) in doc.operations.iter() {
880        check_selection_set(doc, &operation.node.selection_set, 0, max_depth)?;
881    }
882
883    Ok(())
884}
885
886fn remove_skipped_selection(selection_set: &mut SelectionSet, variables: &Variables) {
887    fn is_skipped(directives: &[Positioned<Directive>], variables: &Variables) -> bool {
888        for directive in directives {
889            let include = match &*directive.node.name.node {
890                "skip" => false,
891                "include" => true,
892                _ => continue,
893            };
894
895            if let Some(condition_input) = directive.node.get_argument("if") {
896                let value = condition_input
897                    .node
898                    .clone()
899                    .into_const_with(|name| variables.get(&name).cloned().ok_or(()))
900                    .unwrap_or_default();
901                let value: bool = InputType::parse(Some(value)).unwrap_or_default();
902                if include != value {
903                    return true;
904                }
905            }
906        }
907
908        false
909    }
910
911    selection_set
912        .items
913        .retain(|selection| !is_skipped(selection.node.directives(), variables));
914
915    for selection in &mut selection_set.items {
916        selection.node.directives_mut().retain(|directive| {
917            directive.node.name.node != "skip" && directive.node.name.node != "include"
918        });
919    }
920
921    for selection in &mut selection_set.items {
922        match &mut selection.node {
923            Selection::Field(field) => {
924                remove_skipped_selection(&mut field.node.selection_set.node, variables);
925            }
926            Selection::FragmentSpread(_) => {}
927            Selection::InlineFragment(inline_fragment) => {
928                remove_skipped_selection(&mut inline_fragment.node.selection_set.node, variables);
929            }
930        }
931    }
932}
933
934#[allow(clippy::too_many_arguments)]
935pub(crate) async fn prepare_request(
936    mut extensions: Extensions,
937    request: Request,
938    session_data: Arc<Data>,
939    registry: &Registry,
940    validation_mode: ValidationMode,
941    recursive_depth: usize,
942    max_directives: Option<usize>,
943    max_aliases: Option<usize>,
944    complexity: Option<usize>,
945    depth: Option<usize>,
946) -> Result<(QueryEnv, CacheControl), Vec<ServerError>> {
947    let mut request = extensions.prepare_request(request).await?;
948    let query_data = Arc::new(std::mem::take(&mut request.data));
949    extensions.attach_query_data(query_data.clone());
950
951    let mut document = {
952        let query = &request.query;
953        let parsed_doc = request.parsed_query.take();
954        let fut_parse = async move {
955            let doc = match parsed_doc {
956                Some(parsed_doc) => parsed_doc,
957                None => parse_query(query)?,
958            };
959            check_recursive_depth(&doc, recursive_depth)?;
960            if let Some(max_directives) = max_directives {
961                check_max_directives(&doc, max_directives)?;
962            }
963
964            if let Some(max_aliases) = max_aliases {
965                check_alias_count(&doc, max_aliases)?;
966            }
967
968            Ok(doc)
969        };
970        futures_util::pin_mut!(fut_parse);
971        extensions
972            .parse_query(query, &request.variables, &mut fut_parse)
973            .await?
974    };
975
976    // check rules
977    let validation_result = {
978        let validation_fut = async {
979            check_rules(
980                registry,
981                &document,
982                Some(&request.variables),
983                request.operation_name.as_deref(),
984                validation_mode,
985                complexity,
986                depth,
987            )
988        };
989        futures_util::pin_mut!(validation_fut);
990        extensions.validation(&mut validation_fut).await?
991    };
992
993    let operation = if let Some(operation_name) = &request.operation_name {
994        match document.operations {
995            DocumentOperations::Single(_) => None,
996            DocumentOperations::Multiple(mut operations) => operations
997                .remove(operation_name.as_str())
998                .map(|operation| (Some(operation_name.clone()), operation)),
999        }
1000        .ok_or_else(|| {
1001            ServerError::new(
1002                format!(r#"Unknown operation named "{}""#, operation_name),
1003                None,
1004            )
1005        })
1006    } else {
1007        match document.operations {
1008            DocumentOperations::Single(operation) => Ok((None, operation)),
1009            DocumentOperations::Multiple(map) if map.len() == 1 => {
1010                let (operation_name, operation) = map.into_iter().next().unwrap();
1011                Ok((Some(operation_name.to_string()), operation))
1012            }
1013            DocumentOperations::Multiple(_) => Err(ServerError::new(
1014                "Operation name required in request.",
1015                None,
1016            )),
1017        }
1018    };
1019
1020    let (operation_name, mut operation) = operation.map_err(|err| vec![err])?;
1021
1022    // remove skipped fields
1023    for fragment in document.fragments.values_mut() {
1024        remove_skipped_selection(&mut fragment.node.selection_set.node, &request.variables);
1025    }
1026    remove_skipped_selection(&mut operation.node.selection_set.node, &request.variables);
1027
1028    let env = QueryEnvInner {
1029        extensions,
1030        variables: request.variables,
1031        operation_name,
1032        operation,
1033        fragments: document.fragments,
1034        uploads: request.uploads,
1035        session_data,
1036        query_data,
1037        http_headers: Default::default(),
1038        introspection_mode: request.introspection_mode,
1039        errors: Default::default(),
1040    };
1041    Ok((QueryEnv::new(env), validation_result.cache_control))
1042}