Skip to main content

async_graphql/dynamic/
schema.rs

1use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
2
3use async_graphql_parser::types::OperationType;
4use futures_util::{StreamExt, TryFutureExt, stream::BoxStream};
5use indexmap::IndexMap;
6
7use crate::{
8    Data, Executor, IntrospectionMode, QueryEnv, Request, Response, SDLExportOptions, SchemaEnv,
9    ServerError, ServerResult, ValidationMode,
10    dynamic::{
11        DynamicRequest, FieldFuture, FieldValue, Object, ResolverContext, Scalar, SchemaError,
12        Subscription, TypeRef, Union, field::BoxResolverFn, resolve::resolve_container,
13        r#type::Type,
14    },
15    extensions::{ExtensionFactory, Extensions},
16    registry::{MetaType, Registry},
17    schema::{SchemaEnvInner, prepare_request},
18};
19
20/// Dynamic schema builder
21pub struct SchemaBuilder {
22    query_type: String,
23    mutation_type: Option<String>,
24    subscription_type: Option<String>,
25    types: IndexMap<String, Type>,
26    data: Data,
27    extensions: Vec<Box<dyn ExtensionFactory>>,
28    validation_mode: ValidationMode,
29    recursive_depth: usize,
30    max_directives: Option<usize>,
31    max_aliases: Option<usize>,
32    complexity: Option<usize>,
33    depth: Option<usize>,
34    enable_suggestions: bool,
35    introspection_mode: IntrospectionMode,
36    enable_federation: bool,
37    entity_resolver: Option<BoxResolverFn>,
38}
39
40impl SchemaBuilder {
41    /// Register a GraphQL type
42    #[must_use]
43    pub fn register(mut self, ty: impl Into<Type>) -> Self {
44        let ty = ty.into();
45        self.types.insert(ty.name().to_string(), ty);
46        self
47    }
48
49    /// Enable uploading files (register Upload type).
50    pub fn enable_uploading(mut self) -> Self {
51        self.types.insert(TypeRef::UPLOAD.to_string(), Type::Upload);
52        self
53    }
54
55    /// Add a global data that can be accessed in the `Schema`. You access it
56    /// with `Context::data`.
57    #[must_use]
58    pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
59        self.data.insert(data);
60        self
61    }
62
63    /// Add an extension to the schema.
64    #[must_use]
65    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
66        self.extensions.push(Box::new(extension));
67        self
68    }
69
70    /// Set the maximum complexity a query can have. By default, there is no
71    /// limit.
72    #[must_use]
73    pub fn limit_complexity(mut self, complexity: usize) -> Self {
74        self.complexity = Some(complexity);
75        self
76    }
77
78    /// Set the maximum depth a query can have. By default, there is no limit.
79    #[must_use]
80    pub fn limit_depth(mut self, depth: usize) -> Self {
81        self.depth = Some(depth);
82        self
83    }
84
85    /// Set the maximum recursive depth a query can have. (default: 32)
86    ///
87    /// If the value is too large, stack overflow may occur, usually `32` is
88    /// enough.
89    #[must_use]
90    pub fn limit_recursive_depth(mut self, depth: usize) -> Self {
91        self.recursive_depth = depth;
92        self
93    }
94
95    /// Set the maximum number of directives on a single field. (default: no
96    /// limit)
97    pub fn limit_directives(mut self, max_directives: usize) -> Self {
98        self.max_directives = Some(max_directives);
99        self
100    }
101
102    /// Set the maximum number of aliases on a single query. (default: no
103    /// limit)
104    pub fn limit_aliases(mut self, max_aliases: usize) -> Self {
105        self.max_aliases = Some(max_aliases);
106        self
107    }
108
109    /// Set the validation mode, default is `ValidationMode::Strict`.
110    #[must_use]
111    pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self {
112        self.validation_mode = validation_mode;
113        self
114    }
115
116    /// Disable field suggestions.
117    #[must_use]
118    pub fn disable_suggestions(mut self) -> Self {
119        self.enable_suggestions = false;
120        self
121    }
122
123    /// Disable introspection queries.
124    #[must_use]
125    pub fn disable_introspection(mut self) -> Self {
126        self.introspection_mode = IntrospectionMode::Disabled;
127        self
128    }
129
130    /// Only process introspection queries, everything else is processed as an
131    /// error.
132    #[must_use]
133    pub fn introspection_only(mut self) -> Self {
134        self.introspection_mode = IntrospectionMode::IntrospectionOnly;
135        self
136    }
137
138    /// Enable federation, which is automatically enabled if the Query has least
139    /// one entity definition.
140    #[must_use]
141    pub fn enable_federation(mut self) -> Self {
142        self.enable_federation = true;
143        self
144    }
145
146    /// Set the entity resolver for federation
147    pub fn entity_resolver<F>(self, resolver_fn: F) -> Self
148    where
149        F: for<'a> Fn(ResolverContext<'a>) -> FieldFuture<'a> + Send + Sync + 'static,
150    {
151        Self {
152            entity_resolver: Some(Box::new(resolver_fn)),
153            ..self
154        }
155    }
156
157    /// Consumes this builder and returns a schema.
158    pub fn finish(mut self) -> Result<Schema, SchemaError> {
159        let mut registry = Registry {
160            types: Default::default(),
161            directives: Default::default(),
162            implements: Default::default(),
163            query_type: self.query_type,
164            mutation_type: self.mutation_type,
165            subscription_type: self.subscription_type,
166            introspection_mode: self.introspection_mode,
167            enable_federation: false,
168            federation_subscription: false,
169            ignore_name_conflicts: Default::default(),
170            enable_suggestions: self.enable_suggestions,
171        };
172        registry.add_system_types();
173
174        for ty in self.types.values() {
175            ty.register(&mut registry)?;
176        }
177        update_interface_possible_types(&mut self.types, &mut registry);
178
179        // create system scalars
180        for ty in ["Int", "Float", "Boolean", "String", "ID"] {
181            self.types
182                .insert(ty.to_string(), Type::Scalar(Scalar::new(ty)));
183        }
184
185        // create introspection types
186        if matches!(
187            self.introspection_mode,
188            IntrospectionMode::Enabled | IntrospectionMode::IntrospectionOnly
189        ) {
190            registry.create_introspection_types();
191        }
192
193        // create entity types
194        if self.enable_federation || registry.has_entities() {
195            registry.enable_federation = true;
196            registry.create_federation_types();
197
198            // create _Entity type
199            let entity = self
200                .types
201                .values()
202                .filter(|ty| match ty {
203                    Type::Object(obj) => obj.is_entity(),
204                    Type::Interface(interface) => interface.is_entity(),
205                    _ => false,
206                })
207                .fold(Union::new("_Entity"), |entity, ty| {
208                    entity.possible_type(ty.name())
209                });
210            self.types
211                .insert("_Entity".to_string(), Type::Union(entity));
212        }
213
214        let inner = SchemaInner {
215            env: SchemaEnv(Arc::new(SchemaEnvInner {
216                registry,
217                data: self.data,
218                custom_directives: Default::default(),
219            })),
220            extensions: self.extensions,
221            types: self.types,
222            recursive_depth: self.recursive_depth,
223            max_directives: self.max_directives,
224            max_aliases: self.max_aliases,
225            complexity: self.complexity,
226            depth: self.depth,
227            validation_mode: self.validation_mode,
228            entity_resolver: self.entity_resolver,
229        };
230        inner.check()?;
231        Ok(Schema(Arc::new(inner)))
232    }
233}
234
235/// Dynamic GraphQL schema.
236///
237/// Cloning a schema is cheap, so it can be easily shared.
238#[derive(Clone)]
239pub struct Schema(pub(crate) Arc<SchemaInner>);
240
241impl Debug for Schema {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("Schema").finish()
244    }
245}
246
247pub struct SchemaInner {
248    pub(crate) env: SchemaEnv,
249    pub(crate) types: IndexMap<String, Type>,
250    extensions: Vec<Box<dyn ExtensionFactory>>,
251    recursive_depth: usize,
252    max_directives: Option<usize>,
253    max_aliases: Option<usize>,
254    complexity: Option<usize>,
255    depth: Option<usize>,
256    validation_mode: ValidationMode,
257    pub(crate) entity_resolver: Option<BoxResolverFn>,
258}
259
260impl Schema {
261    /// Create a schema builder
262    pub fn build(query: &str, mutation: Option<&str>, subscription: Option<&str>) -> SchemaBuilder {
263        SchemaBuilder {
264            query_type: query.to_string(),
265            mutation_type: mutation.map(ToString::to_string),
266            subscription_type: subscription.map(ToString::to_string),
267            types: Default::default(),
268            data: Default::default(),
269            extensions: Default::default(),
270            validation_mode: ValidationMode::Strict,
271            recursive_depth: 32,
272            max_directives: None,
273            max_aliases: None,
274            complexity: None,
275            depth: None,
276            enable_suggestions: true,
277            introspection_mode: IntrospectionMode::Enabled,
278            entity_resolver: None,
279            enable_federation: false,
280        }
281    }
282
283    fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
284        Extensions::new(
285            self.0.extensions.iter().map(|f| f.create()),
286            self.0.env.clone(),
287            session_data,
288        )
289    }
290
291    fn query_root(&self) -> ServerResult<&Object> {
292        self.0
293            .types
294            .get(&self.0.env.registry.query_type)
295            .and_then(Type::as_object)
296            .ok_or_else(|| ServerError::new("Query root not found", None))
297    }
298
299    fn mutation_root(&self) -> ServerResult<&Object> {
300        self.0
301            .env
302            .registry
303            .mutation_type
304            .as_ref()
305            .and_then(|mutation_name| self.0.types.get(mutation_name))
306            .and_then(Type::as_object)
307            .ok_or_else(|| ServerError::new("Mutation root not found", None))
308    }
309
310    fn subscription_root(&self) -> ServerResult<&Subscription> {
311        self.0
312            .env
313            .registry
314            .subscription_type
315            .as_ref()
316            .and_then(|subscription_name| self.0.types.get(subscription_name))
317            .and_then(Type::as_subscription)
318            .ok_or_else(|| ServerError::new("Subscription root not found", None))
319    }
320
321    /// Returns SDL(Schema Definition Language) of this schema.
322    pub fn sdl(&self) -> String {
323        self.0.env.registry.export_sdl(Default::default())
324    }
325
326    /// Returns SDL(Schema Definition Language) of this schema with options.
327    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
328        self.0.env.registry.export_sdl(options)
329    }
330
331    async fn execute_once(
332        &self,
333        env: QueryEnv,
334        root_value: &FieldValue<'static>,
335        execute_data: Option<Data>,
336    ) -> Response {
337        // execute
338        let ctx = env.create_context(
339            &self.0.env,
340            None,
341            &env.operation.node.selection_set,
342            execute_data.as_ref(),
343        );
344        let res = match &env.operation.node.ty {
345            OperationType::Query => {
346                async move { self.query_root() }
347                    .and_then(|query_root| {
348                        resolve_container(self, query_root, &ctx, root_value, false)
349                    })
350                    .await
351            }
352            OperationType::Mutation => {
353                async move { self.mutation_root() }
354                    .and_then(|query_root| {
355                        resolve_container(self, query_root, &ctx, root_value, true)
356                    })
357                    .await
358            }
359            OperationType::Subscription => Err(ServerError::new(
360                "Subscriptions are not supported on this transport.",
361                None,
362            )),
363        };
364
365        let mut resp = match res {
366            Ok(value) => Response::new(value.unwrap_or_default()),
367            Err(err) => Response::from_errors(vec![err]),
368        }
369        .http_headers(std::mem::take(&mut *env.http_headers.lock().unwrap()));
370
371        resp.errors
372            .extend(std::mem::take(&mut *env.errors.lock().unwrap()));
373        resp
374    }
375
376    /// Execute a GraphQL query.
377    pub async fn execute(&self, request: impl Into<DynamicRequest>) -> Response {
378        let request = request.into();
379        let extensions = self.create_extensions(Default::default());
380        let request_fut = {
381            let extensions = extensions.clone();
382            async move {
383                match prepare_request(
384                    extensions,
385                    request.inner,
386                    Default::default(),
387                    &self.0.env.registry,
388                    self.0.validation_mode,
389                    self.0.recursive_depth,
390                    self.0.max_directives,
391                    self.0.max_aliases,
392                    self.0.complexity,
393                    self.0.depth,
394                )
395                .await
396                {
397                    Ok((env, cache_control)) => {
398                        let f = {
399                            |execute_data| {
400                                let env = env.clone();
401                                async move {
402                                    self.execute_once(env, &request.root_value, execute_data)
403                                        .await
404                                        .cache_control(cache_control)
405                                }
406                            }
407                        };
408                        env.extensions
409                            .execute(env.operation_name.as_deref(), f)
410                            .await
411                    }
412                    Err(errors) => Response::from_errors(errors),
413                }
414            }
415        };
416        futures_util::pin_mut!(request_fut);
417        extensions.request(&mut request_fut).await
418    }
419
420    /// Execute a GraphQL subscription with session data.
421    pub fn execute_stream_with_session_data(
422        &self,
423        request: impl Into<DynamicRequest>,
424        session_data: Arc<Data>,
425    ) -> BoxStream<'static, Response> {
426        let schema = self.clone();
427        let request = request.into();
428        let extensions = self.create_extensions(session_data.clone());
429
430        let stream = {
431            let extensions = extensions.clone();
432
433            asynk_strim::stream_fn(|mut yielder| async move {
434                let subscription = match schema.subscription_root() {
435                    Ok(subscription) => subscription,
436                    Err(err) => {
437                        yielder.yield_item(Response::from_errors(vec![err])).await;
438                        return;
439                    }
440                };
441
442                let (env, _) = match prepare_request(
443                    extensions,
444                    request.inner,
445                    session_data,
446                    &schema.0.env.registry,
447                    schema.0.validation_mode,
448                    schema.0.recursive_depth,
449                    schema.0.max_directives,
450                    schema.0.max_aliases,
451                    schema.0.complexity,
452                    schema.0.depth,
453                )
454                .await
455                {
456                    Ok(res) => res,
457                    Err(errors) => {
458                        yielder.yield_item(Response::from_errors(errors)).await;
459                        return;
460                    }
461                };
462
463                if env.operation.node.ty != OperationType::Subscription {
464                    yielder
465                        .yield_item(schema.execute_once(env, &request.root_value, None).await)
466                        .await;
467                    return;
468                }
469
470                let ctx = env.create_context(
471                    &schema.0.env,
472                    None,
473                    &env.operation.node.selection_set,
474                    None,
475                );
476                let mut streams = Vec::new();
477                subscription.collect_streams(&schema, &ctx, &mut streams, &request.root_value);
478
479                let mut stream = futures_util::stream::select_all(streams);
480                while let Some(resp) = stream.next().await {
481                    yielder.yield_item(resp).await;
482                }
483            })
484        };
485        extensions.subscribe(stream.boxed())
486    }
487
488    /// Execute a GraphQL subscription.
489    pub fn execute_stream(
490        &self,
491        request: impl Into<DynamicRequest>,
492    ) -> BoxStream<'static, Response> {
493        self.execute_stream_with_session_data(request, Default::default())
494    }
495
496    /// Returns the registry of this schema.
497    pub fn registry(&self) -> &Registry {
498        &self.0.env.registry
499    }
500}
501
502#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
503impl Executor for Schema {
504    async fn execute(&self, request: Request) -> Response {
505        Schema::execute(self, request).await
506    }
507
508    fn execute_stream(
509        &self,
510        request: Request,
511        session_data: Option<Arc<Data>>,
512    ) -> BoxStream<'static, Response> {
513        Schema::execute_stream_with_session_data(self, request, session_data.unwrap_or_default())
514    }
515}
516
517fn update_interface_possible_types(types: &mut IndexMap<String, Type>, registry: &mut Registry) {
518    let mut interfaces = registry
519        .types
520        .values_mut()
521        .filter_map(|ty| match ty {
522            MetaType::Interface {
523                name,
524                possible_types,
525                ..
526            } => Some((name, possible_types)),
527            _ => None,
528        })
529        .collect::<HashMap<_, _>>();
530
531    let objs = types.values().filter_map(|ty| match ty {
532        Type::Object(obj) => Some((&obj.name, &obj.implements)),
533        _ => None,
534    });
535
536    for (obj_name, implements) in objs {
537        for interface in implements {
538            if let Some(possible_types) = interfaces.get_mut(interface) {
539                possible_types.insert(obj_name.clone());
540            }
541        }
542    }
543}
544
545#[cfg(test)]
546mod tests {
547    use std::sync::Arc;
548
549    use async_graphql_parser::{Pos, types::ExecutableDocument};
550    use async_graphql_value::Variables;
551    use futures_util::{StreamExt, stream::BoxStream};
552    use tokio::sync::Mutex;
553
554    use crate::{
555        PathSegment, Request, Response, ServerError, ServerResult, ValidationResult, Value,
556        dynamic::*, extensions::*, value,
557    };
558
559    #[tokio::test]
560    async fn basic_query() {
561        let myobj = Object::new("MyObj")
562            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
563                FieldFuture::new(async { Ok(Some(Value::from(123))) })
564            }))
565            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
566                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
567            }));
568
569        let query = Object::new("Query")
570            .field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
571                FieldFuture::new(async { Ok(Some(Value::from(100))) })
572            }))
573            .field(Field::new(
574                "valueObj",
575                TypeRef::named_nn(myobj.type_name()),
576                |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
577            ));
578        let schema = Schema::build("Query", None, None)
579            .register(query)
580            .register(myobj)
581            .finish()
582            .unwrap();
583
584        assert_eq!(
585            schema
586                .execute("{ value valueObj { a b } }")
587                .await
588                .into_result()
589                .unwrap()
590                .data,
591            value!({
592                "value": 100,
593                "valueObj": {
594                    "a": 123,
595                    "b": "abc",
596                }
597            })
598        );
599    }
600
601    #[tokio::test]
602    async fn root_value() {
603        let query =
604            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |ctx| {
605                FieldFuture::new(async {
606                    Ok(Some(Value::Number(
607                        (*ctx.parent_value.try_downcast_ref::<i32>()?).into(),
608                    )))
609                })
610            }));
611
612        let schema = Schema::build("Query", None, None)
613            .register(query)
614            .finish()
615            .unwrap();
616        assert_eq!(
617            schema
618                .execute("{ value }".root_value(FieldValue::owned_any(100)))
619                .await
620                .into_result()
621                .unwrap()
622                .data,
623            value!({ "value": 100, })
624        );
625    }
626
627    #[tokio::test]
628    async fn field_alias() {
629        let query =
630            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
631                FieldFuture::new(async { Ok(Some(Value::from(100))) })
632            }));
633        let schema = Schema::build("Query", None, None)
634            .register(query)
635            .finish()
636            .unwrap();
637
638        assert_eq!(
639            schema
640                .execute("{ a: value }")
641                .await
642                .into_result()
643                .unwrap()
644                .data,
645            value!({
646                "a": 100,
647            })
648        );
649    }
650
651    #[tokio::test]
652    async fn fragment_spread() {
653        let myobj = Object::new("MyObj")
654            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
655                FieldFuture::new(async { Ok(Some(Value::from(123))) })
656            }))
657            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
658                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
659            }));
660
661        let query = Object::new("Query").field(Field::new(
662            "valueObj",
663            TypeRef::named_nn(myobj.type_name()),
664            |_| FieldFuture::new(async { Ok(Some(Value::Null)) }),
665        ));
666        let schema = Schema::build("Query", None, None)
667            .register(query)
668            .register(myobj)
669            .finish()
670            .unwrap();
671
672        let query = r#"
673            fragment A on MyObj {
674                a b
675            }
676
677            { valueObj { ... A } }
678            "#;
679
680        assert_eq!(
681            schema.execute(query).await.into_result().unwrap().data,
682            value!({
683                "valueObj": {
684                    "a": 123,
685                    "b": "abc",
686                }
687            })
688        );
689    }
690
691    #[tokio::test]
692    async fn inline_fragment() {
693        let myobj = Object::new("MyObj")
694            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
695                FieldFuture::new(async { Ok(Some(Value::from(123))) })
696            }))
697            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
698                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
699            }));
700
701        let query = Object::new("Query").field(Field::new(
702            "valueObj",
703            TypeRef::named_nn(myobj.type_name()),
704            |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
705        ));
706        let schema = Schema::build("Query", None, None)
707            .register(query)
708            .register(myobj)
709            .finish()
710            .unwrap();
711
712        let query = r#"
713            {
714                valueObj {
715                     ... on MyObj { a }
716                     ... { b }
717                }
718            }
719            "#;
720
721        assert_eq!(
722            schema.execute(query).await.into_result().unwrap().data,
723            value!({
724                "valueObj": {
725                    "a": 123,
726                    "b": "abc",
727                }
728            })
729        );
730    }
731
732    #[tokio::test]
733    async fn non_null() {
734        let query = Object::new("Query")
735            .field(Field::new(
736                "valueA",
737                TypeRef::named_nn(TypeRef::INT),
738                |_| FieldFuture::new(async { Ok(FieldValue::none()) }),
739            ))
740            .field(Field::new(
741                "valueB",
742                TypeRef::named_nn(TypeRef::INT),
743                |_| FieldFuture::new(async { Ok(Some(Value::from(100))) }),
744            ))
745            .field(Field::new("valueC", TypeRef::named(TypeRef::INT), |_| {
746                FieldFuture::new(async { Ok(FieldValue::none()) })
747            }))
748            .field(Field::new("valueD", TypeRef::named(TypeRef::INT), |_| {
749                FieldFuture::new(async { Ok(Some(Value::from(200))) })
750            }));
751        let schema = Schema::build("Query", None, None)
752            .register(query)
753            .finish()
754            .unwrap();
755
756        assert_eq!(
757            schema
758                .execute("{ valueA }")
759                .await
760                .into_result()
761                .unwrap_err(),
762            vec![ServerError {
763                message: "internal: non-null types require a return value".to_owned(),
764                source: None,
765                locations: vec![Pos { column: 3, line: 1 }],
766                path: vec![PathSegment::Field("valueA".to_owned())],
767                extensions: None,
768            }]
769        );
770
771        assert_eq!(
772            schema
773                .execute("{ valueB }")
774                .await
775                .into_result()
776                .unwrap()
777                .data,
778            value!({
779                "valueB": 100
780            })
781        );
782
783        assert_eq!(
784            schema
785                .execute("{ valueC valueD }")
786                .await
787                .into_result()
788                .unwrap()
789                .data,
790            value!({
791                "valueC": null,
792                "valueD": 200,
793            })
794        );
795    }
796
797    #[tokio::test]
798    async fn list() {
799        let query = Object::new("Query")
800            .field(Field::new(
801                "values",
802                TypeRef::named_nn_list_nn(TypeRef::INT),
803                |_| {
804                    FieldFuture::new(async {
805                        Ok(Some(vec![Value::from(3), Value::from(6), Value::from(9)]))
806                    })
807                },
808            ))
809            .field(Field::new(
810                "values2",
811                TypeRef::named_nn_list_nn(TypeRef::INT),
812                |_| {
813                    FieldFuture::new(async {
814                        Ok(Some(Value::List(vec![
815                            Value::from(3),
816                            Value::from(6),
817                            Value::from(9),
818                        ])))
819                    })
820                },
821            ))
822            .field(Field::new(
823                "values3",
824                TypeRef::named_nn_list(TypeRef::INT),
825                |_| FieldFuture::new(async { Ok(None::<Vec<Value>>) }),
826            ));
827        let schema = Schema::build("Query", None, None)
828            .register(query)
829            .finish()
830            .unwrap();
831
832        assert_eq!(
833            schema
834                .execute("{ values values2 values3 }")
835                .await
836                .into_result()
837                .unwrap()
838                .data,
839            value!({
840                "values": [3, 6, 9],
841                "values2": [3, 6, 9],
842                "values3": null,
843            })
844        );
845    }
846
847    #[tokio::test]
848    async fn extensions() {
849        struct MyExtensionImpl {
850            calls: Arc<Mutex<Vec<&'static str>>>,
851        }
852
853        #[async_trait::async_trait]
854        #[allow(unused_variables)]
855        impl Extension for MyExtensionImpl {
856            async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
857                self.calls.lock().await.push("request_start");
858                let res = next.run(ctx).await;
859                self.calls.lock().await.push("request_end");
860                res
861            }
862
863            fn subscribe<'s>(
864                &self,
865                ctx: &ExtensionContext<'_>,
866                mut stream: BoxStream<'s, Response>,
867                next: NextSubscribe<'_>,
868            ) -> BoxStream<'s, Response> {
869                let calls = self.calls.clone();
870                next.run(
871                    ctx,
872                    Box::pin(asynk_strim::stream_fn(|mut yielder| async move {
873                        calls.lock().await.push("subscribe_start");
874                        while let Some(item) = stream.next().await {
875                            yielder.yield_item(item).await;
876                        }
877                        calls.lock().await.push("subscribe_end");
878                    })),
879                )
880            }
881
882            async fn prepare_request(
883                &self,
884                ctx: &ExtensionContext<'_>,
885                request: Request,
886                next: NextPrepareRequest<'_>,
887            ) -> ServerResult<Request> {
888                self.calls.lock().await.push("prepare_request_start");
889                let res = next.run(ctx, request).await;
890                self.calls.lock().await.push("prepare_request_end");
891                res
892            }
893
894            async fn parse_query(
895                &self,
896                ctx: &ExtensionContext<'_>,
897                query: &str,
898                variables: &Variables,
899                next: NextParseQuery<'_>,
900            ) -> ServerResult<ExecutableDocument> {
901                self.calls.lock().await.push("parse_query_start");
902                let res = next.run(ctx, query, variables).await;
903                self.calls.lock().await.push("parse_query_end");
904                res
905            }
906
907            async fn validation(
908                &self,
909                ctx: &ExtensionContext<'_>,
910                next: NextValidation<'_>,
911            ) -> Result<ValidationResult, Vec<ServerError>> {
912                self.calls.lock().await.push("validation_start");
913                let res = next.run(ctx).await;
914                self.calls.lock().await.push("validation_end");
915                res
916            }
917
918            async fn execute(
919                &self,
920                ctx: &ExtensionContext<'_>,
921                operation_name: Option<&str>,
922                next: NextExecute<'_>,
923            ) -> Response {
924                assert_eq!(operation_name, Some("Abc"));
925                self.calls.lock().await.push("execute_start");
926                let res = next.run(ctx, operation_name).await;
927                self.calls.lock().await.push("execute_end");
928                res
929            }
930
931            async fn resolve(
932                &self,
933                ctx: &ExtensionContext<'_>,
934                info: ResolveInfo<'_>,
935                next: NextResolve<'_>,
936            ) -> ServerResult<Option<Value>> {
937                self.calls.lock().await.push("resolve_start");
938                let res = next.run(ctx, info).await;
939                self.calls.lock().await.push("resolve_end");
940                res
941            }
942        }
943
944        struct MyExtension {
945            calls: Arc<Mutex<Vec<&'static str>>>,
946        }
947
948        impl ExtensionFactory for MyExtension {
949            fn create(&self) -> Arc<dyn Extension> {
950                Arc::new(MyExtensionImpl {
951                    calls: self.calls.clone(),
952                })
953            }
954        }
955
956        {
957            let query = Object::new("Query")
958                .field(Field::new(
959                    "value1",
960                    TypeRef::named_nn(TypeRef::INT),
961                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
962                ))
963                .field(Field::new(
964                    "value2",
965                    TypeRef::named_nn(TypeRef::INT),
966                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
967                ));
968
969            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
970            let schema = Schema::build(query.type_name(), None, None)
971                .register(query)
972                .extension(MyExtension {
973                    calls: calls.clone(),
974                })
975                .finish()
976                .unwrap();
977
978            let _ = schema
979                .execute("query Abc { value1 value2 }")
980                .await
981                .into_result()
982                .unwrap();
983            let calls = calls.lock().await;
984            assert_eq!(
985                &*calls,
986                &vec![
987                    "request_start",
988                    "prepare_request_start",
989                    "prepare_request_end",
990                    "parse_query_start",
991                    "parse_query_end",
992                    "validation_start",
993                    "validation_end",
994                    "execute_start",
995                    "resolve_start",
996                    "resolve_end",
997                    "resolve_start",
998                    "resolve_end",
999                    "execute_end",
1000                    "request_end",
1001                ]
1002            );
1003        }
1004
1005        {
1006            let query = Object::new("Query").field(Field::new(
1007                "value1",
1008                TypeRef::named_nn(TypeRef::INT),
1009                |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
1010            ));
1011
1012            let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
1013                "value",
1014                TypeRef::named_nn(TypeRef::INT),
1015                |_| {
1016                    SubscriptionFieldFuture::new(async {
1017                        Ok(futures_util::stream::iter([1, 2, 3])
1018                            .map(|value| Ok(Value::from(value))))
1019                    })
1020                },
1021            ));
1022
1023            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
1024            let schema = Schema::build(query.type_name(), None, Some(subscription.type_name()))
1025                .register(query)
1026                .register(subscription)
1027                .extension(MyExtension {
1028                    calls: calls.clone(),
1029                })
1030                .finish()
1031                .unwrap();
1032
1033            let mut stream = schema.execute_stream("subscription Abc { value }");
1034            while stream.next().await.is_some() {}
1035            let calls = calls.lock().await;
1036            assert_eq!(
1037                &*calls,
1038                &vec![
1039                    "subscribe_start",
1040                    "prepare_request_start",
1041                    "prepare_request_end",
1042                    "parse_query_start",
1043                    "parse_query_end",
1044                    "validation_start",
1045                    "validation_end",
1046                    // push 1
1047                    "execute_start",
1048                    "resolve_start",
1049                    "resolve_end",
1050                    "execute_end",
1051                    // push 2
1052                    "execute_start",
1053                    "resolve_start",
1054                    "resolve_end",
1055                    "execute_end",
1056                    // push 3
1057                    "execute_start",
1058                    "resolve_start",
1059                    "resolve_end",
1060                    "execute_end",
1061                    // end
1062                    "subscribe_end",
1063                ]
1064            );
1065        }
1066    }
1067
1068    #[tokio::test]
1069    async fn federation() {
1070        let user = Object::new("User")
1071            .field(Field::new(
1072                "name",
1073                TypeRef::named_nn(TypeRef::STRING),
1074                |_| FieldFuture::new(async { Ok(Some(FieldValue::value("test"))) }),
1075            ))
1076            .key("name");
1077
1078        let query =
1079            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
1080                FieldFuture::new(async { Ok(Some(Value::from(100))) })
1081            }));
1082
1083        let schema = Schema::build("Query", None, None)
1084            .register(query)
1085            .register(user)
1086            .entity_resolver(|ctx| {
1087                FieldFuture::new(async move {
1088                    let representations = ctx.args.try_get("representations")?.list()?;
1089                    let mut values = Vec::new();
1090
1091                    for item in representations.iter() {
1092                        let item = item.object()?;
1093                        let typename = item
1094                            .try_get("__typename")
1095                            .and_then(|value| value.string())?;
1096
1097                        if typename == "User" {
1098                            values.push(FieldValue::borrowed_any(&()).with_type("User"));
1099                        }
1100                    }
1101
1102                    Ok(Some(FieldValue::list(values)))
1103                })
1104            })
1105            .finish()
1106            .unwrap();
1107
1108        assert_eq!(
1109            schema
1110                .execute(
1111                    r#"
1112                {
1113                    _entities(representations: [{__typename: "User", name: "test"}]) {
1114                        __typename
1115                        ... on User {
1116                            name
1117                        }
1118                    }
1119                }
1120                "#
1121                )
1122                .await
1123                .into_result()
1124                .unwrap()
1125                .data,
1126            value!({
1127                "_entities": [{
1128                    "__typename": "User",
1129                    "name": "test",
1130                }],
1131            })
1132        );
1133    }
1134}