// async_graphql/dynamic/schema.rs

1use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
2
3use async_graphql_parser::types::OperationType;
4use futures_util::{StreamExt, TryFutureExt, stream::BoxStream};
5use indexmap::IndexMap;
6
7use crate::{
8    Data, Executor, IntrospectionMode, QueryEnv, Request, Response, SDLExportOptions, SchemaEnv,
9    ServerError, ServerResult, ValidationMode,
10    dynamic::{
11        DynamicRequest, FieldFuture, FieldValue, Object, ResolverContext, Scalar, SchemaError,
12        Subscription, TypeRef, Union, field::BoxResolverFn, resolve::resolve_container,
13        r#type::Type,
14    },
15    extensions::{ExtensionFactory, Extensions},
16    registry::{MetaType, Registry},
17    schema::{SchemaEnvInner, prepare_request},
18};
19
20/// Dynamic schema builder
21pub struct SchemaBuilder {
22    query_type: String,
23    mutation_type: Option<String>,
24    subscription_type: Option<String>,
25    types: IndexMap<String, Type>,
26    data: Data,
27    extensions: Vec<Box<dyn ExtensionFactory>>,
28    validation_mode: ValidationMode,
29    recursive_depth: usize,
30    max_directives: Option<usize>,
31    complexity: Option<usize>,
32    depth: Option<usize>,
33    enable_suggestions: bool,
34    introspection_mode: IntrospectionMode,
35    enable_federation: bool,
36    entity_resolver: Option<BoxResolverFn>,
37}
38
39impl SchemaBuilder {
40    /// Register a GraphQL type
41    #[must_use]
42    pub fn register(mut self, ty: impl Into<Type>) -> Self {
43        let ty = ty.into();
44        self.types.insert(ty.name().to_string(), ty);
45        self
46    }
47
48    /// Enable uploading files (register Upload type).
49    pub fn enable_uploading(mut self) -> Self {
50        self.types.insert(TypeRef::UPLOAD.to_string(), Type::Upload);
51        self
52    }
53
54    /// Add a global data that can be accessed in the `Schema`. You access it
55    /// with `Context::data`.
56    #[must_use]
57    pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
58        self.data.insert(data);
59        self
60    }
61
62    /// Add an extension to the schema.
63    #[must_use]
64    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
65        self.extensions.push(Box::new(extension));
66        self
67    }
68
69    /// Set the maximum complexity a query can have. By default, there is no
70    /// limit.
71    #[must_use]
72    pub fn limit_complexity(mut self, complexity: usize) -> Self {
73        self.complexity = Some(complexity);
74        self
75    }
76
77    /// Set the maximum depth a query can have. By default, there is no limit.
78    #[must_use]
79    pub fn limit_depth(mut self, depth: usize) -> Self {
80        self.depth = Some(depth);
81        self
82    }
83
84    /// Set the maximum recursive depth a query can have. (default: 32)
85    ///
86    /// If the value is too large, stack overflow may occur, usually `32` is
87    /// enough.
88    #[must_use]
89    pub fn limit_recursive_depth(mut self, depth: usize) -> Self {
90        self.recursive_depth = depth;
91        self
92    }
93
94    /// Set the maximum number of directives on a single field. (default: no
95    /// limit)
96    pub fn limit_directives(mut self, max_directives: usize) -> Self {
97        self.max_directives = Some(max_directives);
98        self
99    }
100
101    /// Set the validation mode, default is `ValidationMode::Strict`.
102    #[must_use]
103    pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self {
104        self.validation_mode = validation_mode;
105        self
106    }
107
108    /// Disable field suggestions.
109    #[must_use]
110    pub fn disable_suggestions(mut self) -> Self {
111        self.enable_suggestions = false;
112        self
113    }
114
115    /// Disable introspection queries.
116    #[must_use]
117    pub fn disable_introspection(mut self) -> Self {
118        self.introspection_mode = IntrospectionMode::Disabled;
119        self
120    }
121
122    /// Only process introspection queries, everything else is processed as an
123    /// error.
124    #[must_use]
125    pub fn introspection_only(mut self) -> Self {
126        self.introspection_mode = IntrospectionMode::IntrospectionOnly;
127        self
128    }
129
130    /// Enable federation, which is automatically enabled if the Query has least
131    /// one entity definition.
132    #[must_use]
133    pub fn enable_federation(mut self) -> Self {
134        self.enable_federation = true;
135        self
136    }
137
138    /// Set the entity resolver for federation
139    pub fn entity_resolver<F>(self, resolver_fn: F) -> Self
140    where
141        F: for<'a> Fn(ResolverContext<'a>) -> FieldFuture<'a> + Send + Sync + 'static,
142    {
143        Self {
144            entity_resolver: Some(Box::new(resolver_fn)),
145            ..self
146        }
147    }
148
149    /// Consumes this builder and returns a schema.
150    pub fn finish(mut self) -> Result<Schema, SchemaError> {
151        let mut registry = Registry {
152            types: Default::default(),
153            directives: Default::default(),
154            implements: Default::default(),
155            query_type: self.query_type,
156            mutation_type: self.mutation_type,
157            subscription_type: self.subscription_type,
158            introspection_mode: self.introspection_mode,
159            enable_federation: false,
160            federation_subscription: false,
161            ignore_name_conflicts: Default::default(),
162            enable_suggestions: self.enable_suggestions,
163        };
164        registry.add_system_types();
165
166        for ty in self.types.values() {
167            ty.register(&mut registry)?;
168        }
169        update_interface_possible_types(&mut self.types, &mut registry);
170
171        // create system scalars
172        for ty in ["Int", "Float", "Boolean", "String", "ID"] {
173            self.types
174                .insert(ty.to_string(), Type::Scalar(Scalar::new(ty)));
175        }
176
177        // create introspection types
178        if matches!(
179            self.introspection_mode,
180            IntrospectionMode::Enabled | IntrospectionMode::IntrospectionOnly
181        ) {
182            registry.create_introspection_types();
183        }
184
185        // create entity types
186        if self.enable_federation || registry.has_entities() {
187            registry.enable_federation = true;
188            registry.create_federation_types();
189
190            // create _Entity type
191            let entity = self
192                .types
193                .values()
194                .filter(|ty| match ty {
195                    Type::Object(obj) => obj.is_entity(),
196                    Type::Interface(interface) => interface.is_entity(),
197                    _ => false,
198                })
199                .fold(Union::new("_Entity"), |entity, ty| {
200                    entity.possible_type(ty.name())
201                });
202            self.types
203                .insert("_Entity".to_string(), Type::Union(entity));
204        }
205
206        let inner = SchemaInner {
207            env: SchemaEnv(Arc::new(SchemaEnvInner {
208                registry,
209                data: self.data,
210                custom_directives: Default::default(),
211            })),
212            extensions: self.extensions,
213            types: self.types,
214            recursive_depth: self.recursive_depth,
215            max_directives: self.max_directives,
216            complexity: self.complexity,
217            depth: self.depth,
218            validation_mode: self.validation_mode,
219            entity_resolver: self.entity_resolver,
220        };
221        inner.check()?;
222        Ok(Schema(Arc::new(inner)))
223    }
224}
225
226/// Dynamic GraphQL schema.
227///
228/// Cloning a schema is cheap, so it can be easily shared.
229#[derive(Clone)]
230pub struct Schema(pub(crate) Arc<SchemaInner>);
231
232impl Debug for Schema {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        f.debug_struct("Schema").finish()
235    }
236}
237
238pub struct SchemaInner {
239    pub(crate) env: SchemaEnv,
240    pub(crate) types: IndexMap<String, Type>,
241    extensions: Vec<Box<dyn ExtensionFactory>>,
242    recursive_depth: usize,
243    max_directives: Option<usize>,
244    complexity: Option<usize>,
245    depth: Option<usize>,
246    validation_mode: ValidationMode,
247    pub(crate) entity_resolver: Option<BoxResolverFn>,
248}
249
250impl Schema {
251    /// Create a schema builder
252    pub fn build(query: &str, mutation: Option<&str>, subscription: Option<&str>) -> SchemaBuilder {
253        SchemaBuilder {
254            query_type: query.to_string(),
255            mutation_type: mutation.map(ToString::to_string),
256            subscription_type: subscription.map(ToString::to_string),
257            types: Default::default(),
258            data: Default::default(),
259            extensions: Default::default(),
260            validation_mode: ValidationMode::Strict,
261            recursive_depth: 32,
262            max_directives: None,
263            complexity: None,
264            depth: None,
265            enable_suggestions: true,
266            introspection_mode: IntrospectionMode::Enabled,
267            entity_resolver: None,
268            enable_federation: false,
269        }
270    }
271
272    fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
273        Extensions::new(
274            self.0.extensions.iter().map(|f| f.create()),
275            self.0.env.clone(),
276            session_data,
277        )
278    }
279
280    fn query_root(&self) -> ServerResult<&Object> {
281        self.0
282            .types
283            .get(&self.0.env.registry.query_type)
284            .and_then(Type::as_object)
285            .ok_or_else(|| ServerError::new("Query root not found", None))
286    }
287
288    fn mutation_root(&self) -> ServerResult<&Object> {
289        self.0
290            .env
291            .registry
292            .mutation_type
293            .as_ref()
294            .and_then(|mutation_name| self.0.types.get(mutation_name))
295            .and_then(Type::as_object)
296            .ok_or_else(|| ServerError::new("Mutation root not found", None))
297    }
298
299    fn subscription_root(&self) -> ServerResult<&Subscription> {
300        self.0
301            .env
302            .registry
303            .subscription_type
304            .as_ref()
305            .and_then(|subscription_name| self.0.types.get(subscription_name))
306            .and_then(Type::as_subscription)
307            .ok_or_else(|| ServerError::new("Subscription root not found", None))
308    }
309
310    /// Returns SDL(Schema Definition Language) of this schema.
311    pub fn sdl(&self) -> String {
312        self.0.env.registry.export_sdl(Default::default())
313    }
314
315    /// Returns SDL(Schema Definition Language) of this schema with options.
316    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
317        self.0.env.registry.export_sdl(options)
318    }
319
320    async fn execute_once(
321        &self,
322        env: QueryEnv,
323        root_value: &FieldValue<'static>,
324        execute_data: Option<Data>,
325    ) -> Response {
326        // execute
327        let ctx = env.create_context(
328            &self.0.env,
329            None,
330            &env.operation.node.selection_set,
331            execute_data.as_ref(),
332        );
333        let res = match &env.operation.node.ty {
334            OperationType::Query => {
335                async move { self.query_root() }
336                    .and_then(|query_root| {
337                        resolve_container(self, query_root, &ctx, root_value, false)
338                    })
339                    .await
340            }
341            OperationType::Mutation => {
342                async move { self.mutation_root() }
343                    .and_then(|query_root| {
344                        resolve_container(self, query_root, &ctx, root_value, true)
345                    })
346                    .await
347            }
348            OperationType::Subscription => Err(ServerError::new(
349                "Subscriptions are not supported on this transport.",
350                None,
351            )),
352        };
353
354        let mut resp = match res {
355            Ok(value) => Response::new(value.unwrap_or_default()),
356            Err(err) => Response::from_errors(vec![err]),
357        }
358        .http_headers(std::mem::take(&mut *env.http_headers.lock().unwrap()));
359
360        resp.errors
361            .extend(std::mem::take(&mut *env.errors.lock().unwrap()));
362        resp
363    }
364
365    /// Execute a GraphQL query.
366    pub async fn execute(&self, request: impl Into<DynamicRequest>) -> Response {
367        let request = request.into();
368        let extensions = self.create_extensions(Default::default());
369        let request_fut = {
370            let extensions = extensions.clone();
371            async move {
372                match prepare_request(
373                    extensions,
374                    request.inner,
375                    Default::default(),
376                    &self.0.env.registry,
377                    self.0.validation_mode,
378                    self.0.recursive_depth,
379                    self.0.max_directives,
380                    self.0.complexity,
381                    self.0.depth,
382                )
383                .await
384                {
385                    Ok((env, cache_control)) => {
386                        let f = {
387                            |execute_data| {
388                                let env = env.clone();
389                                async move {
390                                    self.execute_once(env, &request.root_value, execute_data)
391                                        .await
392                                        .cache_control(cache_control)
393                                }
394                            }
395                        };
396                        env.extensions
397                            .execute(env.operation_name.as_deref(), f)
398                            .await
399                    }
400                    Err(errors) => Response::from_errors(errors),
401                }
402            }
403        };
404        futures_util::pin_mut!(request_fut);
405        extensions.request(&mut request_fut).await
406    }
407
408    /// Execute a GraphQL subscription with session data.
409    pub fn execute_stream_with_session_data(
410        &self,
411        request: impl Into<DynamicRequest>,
412        session_data: Arc<Data>,
413    ) -> BoxStream<'static, Response> {
414        let schema = self.clone();
415        let request = request.into();
416        let extensions = self.create_extensions(session_data.clone());
417
418        let stream = {
419            let extensions = extensions.clone();
420
421            async_stream::stream! {
422                let subscription = match schema.subscription_root() {
423                    Ok(subscription) => subscription,
424                    Err(err) => {
425                        yield Response::from_errors(vec![err]);
426                        return;
427                    }
428                };
429
430                let (env, _) = match prepare_request(
431                    extensions,
432                    request.inner,
433                    session_data,
434                    &schema.0.env.registry,
435                    schema.0.validation_mode,
436                    schema.0.recursive_depth,
437                    schema.0.max_directives,
438                    schema.0.complexity,
439                    schema.0.depth,
440                )
441                .await {
442                    Ok(res) => res,
443                    Err(errors) => {
444                        yield Response::from_errors(errors);
445                        return;
446                    }
447                };
448
449                if env.operation.node.ty != OperationType::Subscription {
450                    yield schema.execute_once(env, &request.root_value, None).await;
451                    return;
452                }
453
454                let ctx = env.create_context(
455                    &schema.0.env,
456                    None,
457                    &env.operation.node.selection_set,
458                    None,
459                );
460                let mut streams = Vec::new();
461                subscription.collect_streams(&schema, &ctx, &mut streams, &request.root_value);
462
463                let mut stream = futures_util::stream::select_all(streams);
464                while let Some(resp) = stream.next().await {
465                    yield resp;
466                }
467            }
468        };
469        extensions.subscribe(stream.boxed())
470    }
471
472    /// Execute a GraphQL subscription.
473    pub fn execute_stream(
474        &self,
475        request: impl Into<DynamicRequest>,
476    ) -> BoxStream<'static, Response> {
477        self.execute_stream_with_session_data(request, Default::default())
478    }
479
480    /// Returns the registry of this schema.
481    pub fn registry(&self) -> &Registry {
482        &self.0.env.registry
483    }
484}
485
486#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
487impl Executor for Schema {
488    async fn execute(&self, request: Request) -> Response {
489        Schema::execute(self, request).await
490    }
491
492    fn execute_stream(
493        &self,
494        request: Request,
495        session_data: Option<Arc<Data>>,
496    ) -> BoxStream<'static, Response> {
497        Schema::execute_stream_with_session_data(self, request, session_data.unwrap_or_default())
498    }
499}
500
501fn update_interface_possible_types(types: &mut IndexMap<String, Type>, registry: &mut Registry) {
502    let mut interfaces = registry
503        .types
504        .values_mut()
505        .filter_map(|ty| match ty {
506            MetaType::Interface {
507                name,
508                possible_types,
509                ..
510            } => Some((name, possible_types)),
511            _ => None,
512        })
513        .collect::<HashMap<_, _>>();
514
515    let objs = types.values().filter_map(|ty| match ty {
516        Type::Object(obj) => Some((&obj.name, &obj.implements)),
517        _ => None,
518    });
519
520    for (obj_name, implements) in objs {
521        for interface in implements {
522            if let Some(possible_types) = interfaces.get_mut(interface) {
523                possible_types.insert(obj_name.clone());
524            }
525        }
526    }
527}
528
529#[cfg(test)]
530mod tests {
531    use std::sync::Arc;
532
533    use async_graphql_parser::{Pos, types::ExecutableDocument};
534    use async_graphql_value::Variables;
535    use futures_util::{StreamExt, stream::BoxStream};
536    use tokio::sync::Mutex;
537
538    use crate::{
539        PathSegment, Request, Response, ServerError, ServerResult, ValidationResult, Value,
540        dynamic::*, extensions::*, value,
541    };
542
543    #[tokio::test]
544    async fn basic_query() {
545        let myobj = Object::new("MyObj")
546            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
547                FieldFuture::new(async { Ok(Some(Value::from(123))) })
548            }))
549            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
550                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
551            }));
552
553        let query = Object::new("Query")
554            .field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
555                FieldFuture::new(async { Ok(Some(Value::from(100))) })
556            }))
557            .field(Field::new(
558                "valueObj",
559                TypeRef::named_nn(myobj.type_name()),
560                |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
561            ));
562        let schema = Schema::build("Query", None, None)
563            .register(query)
564            .register(myobj)
565            .finish()
566            .unwrap();
567
568        assert_eq!(
569            schema
570                .execute("{ value valueObj { a b } }")
571                .await
572                .into_result()
573                .unwrap()
574                .data,
575            value!({
576                "value": 100,
577                "valueObj": {
578                    "a": 123,
579                    "b": "abc",
580                }
581            })
582        );
583    }
584
585    #[tokio::test]
586    async fn root_value() {
587        let query =
588            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |ctx| {
589                FieldFuture::new(async {
590                    Ok(Some(Value::Number(
591                        (*ctx.parent_value.try_downcast_ref::<i32>()?).into(),
592                    )))
593                })
594            }));
595
596        let schema = Schema::build("Query", None, None)
597            .register(query)
598            .finish()
599            .unwrap();
600        assert_eq!(
601            schema
602                .execute("{ value }".root_value(FieldValue::owned_any(100)))
603                .await
604                .into_result()
605                .unwrap()
606                .data,
607            value!({ "value": 100, })
608        );
609    }
610
611    #[tokio::test]
612    async fn field_alias() {
613        let query =
614            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
615                FieldFuture::new(async { Ok(Some(Value::from(100))) })
616            }));
617        let schema = Schema::build("Query", None, None)
618            .register(query)
619            .finish()
620            .unwrap();
621
622        assert_eq!(
623            schema
624                .execute("{ a: value }")
625                .await
626                .into_result()
627                .unwrap()
628                .data,
629            value!({
630                "a": 100,
631            })
632        );
633    }
634
635    #[tokio::test]
636    async fn fragment_spread() {
637        let myobj = Object::new("MyObj")
638            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
639                FieldFuture::new(async { Ok(Some(Value::from(123))) })
640            }))
641            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
642                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
643            }));
644
645        let query = Object::new("Query").field(Field::new(
646            "valueObj",
647            TypeRef::named_nn(myobj.type_name()),
648            |_| FieldFuture::new(async { Ok(Some(Value::Null)) }),
649        ));
650        let schema = Schema::build("Query", None, None)
651            .register(query)
652            .register(myobj)
653            .finish()
654            .unwrap();
655
656        let query = r#"
657            fragment A on MyObj {
658                a b
659            }
660
661            { valueObj { ... A } }
662            "#;
663
664        assert_eq!(
665            schema.execute(query).await.into_result().unwrap().data,
666            value!({
667                "valueObj": {
668                    "a": 123,
669                    "b": "abc",
670                }
671            })
672        );
673    }
674
675    #[tokio::test]
676    async fn inline_fragment() {
677        let myobj = Object::new("MyObj")
678            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
679                FieldFuture::new(async { Ok(Some(Value::from(123))) })
680            }))
681            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
682                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
683            }));
684
685        let query = Object::new("Query").field(Field::new(
686            "valueObj",
687            TypeRef::named_nn(myobj.type_name()),
688            |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
689        ));
690        let schema = Schema::build("Query", None, None)
691            .register(query)
692            .register(myobj)
693            .finish()
694            .unwrap();
695
696        let query = r#"
697            {
698                valueObj {
699                     ... on MyObj { a }
700                     ... { b }
701                }
702            }
703            "#;
704
705        assert_eq!(
706            schema.execute(query).await.into_result().unwrap().data,
707            value!({
708                "valueObj": {
709                    "a": 123,
710                    "b": "abc",
711                }
712            })
713        );
714    }
715
716    #[tokio::test]
717    async fn non_null() {
718        let query = Object::new("Query")
719            .field(Field::new(
720                "valueA",
721                TypeRef::named_nn(TypeRef::INT),
722                |_| FieldFuture::new(async { Ok(FieldValue::none()) }),
723            ))
724            .field(Field::new(
725                "valueB",
726                TypeRef::named_nn(TypeRef::INT),
727                |_| FieldFuture::new(async { Ok(Some(Value::from(100))) }),
728            ))
729            .field(Field::new("valueC", TypeRef::named(TypeRef::INT), |_| {
730                FieldFuture::new(async { Ok(FieldValue::none()) })
731            }))
732            .field(Field::new("valueD", TypeRef::named(TypeRef::INT), |_| {
733                FieldFuture::new(async { Ok(Some(Value::from(200))) })
734            }));
735        let schema = Schema::build("Query", None, None)
736            .register(query)
737            .finish()
738            .unwrap();
739
740        assert_eq!(
741            schema
742                .execute("{ valueA }")
743                .await
744                .into_result()
745                .unwrap_err(),
746            vec![ServerError {
747                message: "internal: non-null types require a return value".to_owned(),
748                source: None,
749                locations: vec![Pos { column: 3, line: 1 }],
750                path: vec![PathSegment::Field("valueA".to_owned())],
751                extensions: None,
752            }]
753        );
754
755        assert_eq!(
756            schema
757                .execute("{ valueB }")
758                .await
759                .into_result()
760                .unwrap()
761                .data,
762            value!({
763                "valueB": 100
764            })
765        );
766
767        assert_eq!(
768            schema
769                .execute("{ valueC valueD }")
770                .await
771                .into_result()
772                .unwrap()
773                .data,
774            value!({
775                "valueC": null,
776                "valueD": 200,
777            })
778        );
779    }
780
781    #[tokio::test]
782    async fn list() {
783        let query = Object::new("Query")
784            .field(Field::new(
785                "values",
786                TypeRef::named_nn_list_nn(TypeRef::INT),
787                |_| {
788                    FieldFuture::new(async {
789                        Ok(Some(vec![Value::from(3), Value::from(6), Value::from(9)]))
790                    })
791                },
792            ))
793            .field(Field::new(
794                "values2",
795                TypeRef::named_nn_list_nn(TypeRef::INT),
796                |_| {
797                    FieldFuture::new(async {
798                        Ok(Some(Value::List(vec![
799                            Value::from(3),
800                            Value::from(6),
801                            Value::from(9),
802                        ])))
803                    })
804                },
805            ))
806            .field(Field::new(
807                "values3",
808                TypeRef::named_nn_list(TypeRef::INT),
809                |_| FieldFuture::new(async { Ok(None::<Vec<Value>>) }),
810            ));
811        let schema = Schema::build("Query", None, None)
812            .register(query)
813            .finish()
814            .unwrap();
815
816        assert_eq!(
817            schema
818                .execute("{ values values2 values3 }")
819                .await
820                .into_result()
821                .unwrap()
822                .data,
823            value!({
824                "values": [3, 6, 9],
825                "values2": [3, 6, 9],
826                "values3": null,
827            })
828        );
829    }
830
831    #[tokio::test]
832    async fn extensions() {
833        struct MyExtensionImpl {
834            calls: Arc<Mutex<Vec<&'static str>>>,
835        }
836
837        #[async_trait::async_trait]
838        #[allow(unused_variables)]
839        impl Extension for MyExtensionImpl {
840            async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
841                self.calls.lock().await.push("request_start");
842                let res = next.run(ctx).await;
843                self.calls.lock().await.push("request_end");
844                res
845            }
846
847            fn subscribe<'s>(
848                &self,
849                ctx: &ExtensionContext<'_>,
850                mut stream: BoxStream<'s, Response>,
851                next: NextSubscribe<'_>,
852            ) -> BoxStream<'s, Response> {
853                let calls = self.calls.clone();
854                next.run(
855                    ctx,
856                    Box::pin(async_stream::stream! {
857                        calls.lock().await.push("subscribe_start");
858                        while let Some(item) = stream.next().await {
859                            yield item;
860                        }
861                        calls.lock().await.push("subscribe_end");
862                    }),
863                )
864            }
865
866            async fn prepare_request(
867                &self,
868                ctx: &ExtensionContext<'_>,
869                request: Request,
870                next: NextPrepareRequest<'_>,
871            ) -> ServerResult<Request> {
872                self.calls.lock().await.push("prepare_request_start");
873                let res = next.run(ctx, request).await;
874                self.calls.lock().await.push("prepare_request_end");
875                res
876            }
877
878            async fn parse_query(
879                &self,
880                ctx: &ExtensionContext<'_>,
881                query: &str,
882                variables: &Variables,
883                next: NextParseQuery<'_>,
884            ) -> ServerResult<ExecutableDocument> {
885                self.calls.lock().await.push("parse_query_start");
886                let res = next.run(ctx, query, variables).await;
887                self.calls.lock().await.push("parse_query_end");
888                res
889            }
890
891            async fn validation(
892                &self,
893                ctx: &ExtensionContext<'_>,
894                next: NextValidation<'_>,
895            ) -> Result<ValidationResult, Vec<ServerError>> {
896                self.calls.lock().await.push("validation_start");
897                let res = next.run(ctx).await;
898                self.calls.lock().await.push("validation_end");
899                res
900            }
901
902            async fn execute(
903                &self,
904                ctx: &ExtensionContext<'_>,
905                operation_name: Option<&str>,
906                next: NextExecute<'_>,
907            ) -> Response {
908                assert_eq!(operation_name, Some("Abc"));
909                self.calls.lock().await.push("execute_start");
910                let res = next.run(ctx, operation_name).await;
911                self.calls.lock().await.push("execute_end");
912                res
913            }
914
915            async fn resolve(
916                &self,
917                ctx: &ExtensionContext<'_>,
918                info: ResolveInfo<'_>,
919                next: NextResolve<'_>,
920            ) -> ServerResult<Option<Value>> {
921                self.calls.lock().await.push("resolve_start");
922                let res = next.run(ctx, info).await;
923                self.calls.lock().await.push("resolve_end");
924                res
925            }
926        }
927
928        struct MyExtension {
929            calls: Arc<Mutex<Vec<&'static str>>>,
930        }
931
932        impl ExtensionFactory for MyExtension {
933            fn create(&self) -> Arc<dyn Extension> {
934                Arc::new(MyExtensionImpl {
935                    calls: self.calls.clone(),
936                })
937            }
938        }
939
940        {
941            let query = Object::new("Query")
942                .field(Field::new(
943                    "value1",
944                    TypeRef::named_nn(TypeRef::INT),
945                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
946                ))
947                .field(Field::new(
948                    "value2",
949                    TypeRef::named_nn(TypeRef::INT),
950                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
951                ));
952
953            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
954            let schema = Schema::build(query.type_name(), None, None)
955                .register(query)
956                .extension(MyExtension {
957                    calls: calls.clone(),
958                })
959                .finish()
960                .unwrap();
961
962            let _ = schema
963                .execute("query Abc { value1 value2 }")
964                .await
965                .into_result()
966                .unwrap();
967            let calls = calls.lock().await;
968            assert_eq!(
969                &*calls,
970                &vec![
971                    "request_start",
972                    "prepare_request_start",
973                    "prepare_request_end",
974                    "parse_query_start",
975                    "parse_query_end",
976                    "validation_start",
977                    "validation_end",
978                    "execute_start",
979                    "resolve_start",
980                    "resolve_end",
981                    "resolve_start",
982                    "resolve_end",
983                    "execute_end",
984                    "request_end",
985                ]
986            );
987        }
988
989        {
990            let query = Object::new("Query").field(Field::new(
991                "value1",
992                TypeRef::named_nn(TypeRef::INT),
993                |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
994            ));
995
996            let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
997                "value",
998                TypeRef::named_nn(TypeRef::INT),
999                |_| {
1000                    SubscriptionFieldFuture::new(async {
1001                        Ok(futures_util::stream::iter([1, 2, 3])
1002                            .map(|value| Ok(Value::from(value))))
1003                    })
1004                },
1005            ));
1006
1007            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
1008            let schema = Schema::build(query.type_name(), None, Some(subscription.type_name()))
1009                .register(query)
1010                .register(subscription)
1011                .extension(MyExtension {
1012                    calls: calls.clone(),
1013                })
1014                .finish()
1015                .unwrap();
1016
1017            let mut stream = schema.execute_stream("subscription Abc { value }");
1018            while stream.next().await.is_some() {}
1019            let calls = calls.lock().await;
1020            assert_eq!(
1021                &*calls,
1022                &vec![
1023                    "subscribe_start",
1024                    "prepare_request_start",
1025                    "prepare_request_end",
1026                    "parse_query_start",
1027                    "parse_query_end",
1028                    "validation_start",
1029                    "validation_end",
1030                    // push 1
1031                    "execute_start",
1032                    "resolve_start",
1033                    "resolve_end",
1034                    "execute_end",
1035                    // push 2
1036                    "execute_start",
1037                    "resolve_start",
1038                    "resolve_end",
1039                    "execute_end",
1040                    // push 3
1041                    "execute_start",
1042                    "resolve_start",
1043                    "resolve_end",
1044                    "execute_end",
1045                    // end
1046                    "subscribe_end",
1047                ]
1048            );
1049        }
1050    }
1051
1052    #[tokio::test]
1053    async fn federation() {
1054        let user = Object::new("User")
1055            .field(Field::new(
1056                "name",
1057                TypeRef::named_nn(TypeRef::STRING),
1058                |_| FieldFuture::new(async { Ok(Some(FieldValue::value("test"))) }),
1059            ))
1060            .key("name");
1061
1062        let query =
1063            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
1064                FieldFuture::new(async { Ok(Some(Value::from(100))) })
1065            }));
1066
1067        let schema = Schema::build("Query", None, None)
1068            .register(query)
1069            .register(user)
1070            .entity_resolver(|ctx| {
1071                FieldFuture::new(async move {
1072                    let representations = ctx.args.try_get("representations")?.list()?;
1073                    let mut values = Vec::new();
1074
1075                    for item in representations.iter() {
1076                        let item = item.object()?;
1077                        let typename = item
1078                            .try_get("__typename")
1079                            .and_then(|value| value.string())?;
1080
1081                        if typename == "User" {
1082                            values.push(FieldValue::borrowed_any(&()).with_type("User"));
1083                        }
1084                    }
1085
1086                    Ok(Some(FieldValue::list(values)))
1087                })
1088            })
1089            .finish()
1090            .unwrap();
1091
1092        assert_eq!(
1093            schema
1094                .execute(
1095                    r#"
1096                {
1097                    _entities(representations: [{__typename: "User", name: "test"}]) {
1098                        __typename
1099                        ... on User {
1100                            name
1101                        }
1102                    }
1103                }
1104                "#
1105                )
1106                .await
1107                .into_result()
1108                .unwrap()
1109                .data,
1110            value!({
1111                "_entities": [{
1112                    "__typename": "User",
1113                    "name": "test",
1114                }],
1115            })
1116        );
1117    }
1118}