async_graphql/dynamic/
schema.rs

1use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
2
3use async_graphql_parser::types::OperationType;
4use futures_util::{StreamExt, TryFutureExt, stream::BoxStream};
5use indexmap::IndexMap;
6
7use crate::{
8    Data, Executor, IntrospectionMode, QueryEnv, Request, Response, SDLExportOptions, SchemaEnv,
9    ServerError, ServerResult, ValidationMode,
10    dynamic::{
11        DynamicRequest, FieldFuture, FieldValue, Object, ResolverContext, Scalar, SchemaError,
12        Subscription, TypeRef, Union, field::BoxResolverFn, resolve::resolve_container,
13        r#type::Type,
14    },
15    extensions::{ExtensionFactory, Extensions},
16    registry::{MetaType, Registry},
17    schema::{SchemaEnvInner, prepare_request},
18};
19
20/// Dynamic schema builder
21pub struct SchemaBuilder {
22    query_type: String,
23    mutation_type: Option<String>,
24    subscription_type: Option<String>,
25    types: IndexMap<String, Type>,
26    data: Data,
27    extensions: Vec<Box<dyn ExtensionFactory>>,
28    validation_mode: ValidationMode,
29    recursive_depth: usize,
30    max_directives: Option<usize>,
31    complexity: Option<usize>,
32    depth: Option<usize>,
33    enable_suggestions: bool,
34    introspection_mode: IntrospectionMode,
35    enable_federation: bool,
36    entity_resolver: Option<BoxResolverFn>,
37}
38
39impl SchemaBuilder {
40    /// Register a GraphQL type
41    #[must_use]
42    pub fn register(mut self, ty: impl Into<Type>) -> Self {
43        let ty = ty.into();
44        self.types.insert(ty.name().to_string(), ty);
45        self
46    }
47
48    /// Enable uploading files (register Upload type).
49    pub fn enable_uploading(mut self) -> Self {
50        self.types.insert(TypeRef::UPLOAD.to_string(), Type::Upload);
51        self
52    }
53
54    /// Add a global data that can be accessed in the `Schema`. You access it
55    /// with `Context::data`.
56    #[must_use]
57    pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
58        self.data.insert(data);
59        self
60    }
61
62    /// Add an extension to the schema.
63    #[must_use]
64    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
65        self.extensions.push(Box::new(extension));
66        self
67    }
68
69    /// Set the maximum complexity a query can have. By default, there is no
70    /// limit.
71    #[must_use]
72    pub fn limit_complexity(mut self, complexity: usize) -> Self {
73        self.complexity = Some(complexity);
74        self
75    }
76
77    /// Set the maximum depth a query can have. By default, there is no limit.
78    #[must_use]
79    pub fn limit_depth(mut self, depth: usize) -> Self {
80        self.depth = Some(depth);
81        self
82    }
83
84    /// Set the maximum recursive depth a query can have. (default: 32)
85    ///
86    /// If the value is too large, stack overflow may occur, usually `32` is
87    /// enough.
88    #[must_use]
89    pub fn limit_recursive_depth(mut self, depth: usize) -> Self {
90        self.recursive_depth = depth;
91        self
92    }
93
94    /// Set the maximum number of directives on a single field. (default: no
95    /// limit)
96    pub fn limit_directives(mut self, max_directives: usize) -> Self {
97        self.max_directives = Some(max_directives);
98        self
99    }
100
101    /// Set the validation mode, default is `ValidationMode::Strict`.
102    #[must_use]
103    pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self {
104        self.validation_mode = validation_mode;
105        self
106    }
107
108    /// Disable field suggestions.
109    #[must_use]
110    pub fn disable_suggestions(mut self) -> Self {
111        self.enable_suggestions = false;
112        self
113    }
114
115    /// Disable introspection queries.
116    #[must_use]
117    pub fn disable_introspection(mut self) -> Self {
118        self.introspection_mode = IntrospectionMode::Disabled;
119        self
120    }
121
122    /// Only process introspection queries, everything else is processed as an
123    /// error.
124    #[must_use]
125    pub fn introspection_only(mut self) -> Self {
126        self.introspection_mode = IntrospectionMode::IntrospectionOnly;
127        self
128    }
129
130    /// Enable federation, which is automatically enabled if the Query has least
131    /// one entity definition.
132    #[must_use]
133    pub fn enable_federation(mut self) -> Self {
134        self.enable_federation = true;
135        self
136    }
137
138    /// Set the entity resolver for federation
139    pub fn entity_resolver<F>(self, resolver_fn: F) -> Self
140    where
141        F: for<'a> Fn(ResolverContext<'a>) -> FieldFuture<'a> + Send + Sync + 'static,
142    {
143        Self {
144            entity_resolver: Some(Box::new(resolver_fn)),
145            ..self
146        }
147    }
148
149    /// Consumes this builder and returns a schema.
150    pub fn finish(mut self) -> Result<Schema, SchemaError> {
151        let mut registry = Registry {
152            types: Default::default(),
153            directives: Default::default(),
154            implements: Default::default(),
155            query_type: self.query_type,
156            mutation_type: self.mutation_type,
157            subscription_type: self.subscription_type,
158            introspection_mode: self.introspection_mode,
159            enable_federation: false,
160            federation_subscription: false,
161            ignore_name_conflicts: Default::default(),
162            enable_suggestions: self.enable_suggestions,
163        };
164        registry.add_system_types();
165
166        for ty in self.types.values() {
167            ty.register(&mut registry)?;
168        }
169        update_interface_possible_types(&mut self.types, &mut registry);
170
171        // create system scalars
172        for ty in ["Int", "Float", "Boolean", "String", "ID"] {
173            self.types
174                .insert(ty.to_string(), Type::Scalar(Scalar::new(ty)));
175        }
176
177        // create introspection types
178        if matches!(
179            self.introspection_mode,
180            IntrospectionMode::Enabled | IntrospectionMode::IntrospectionOnly
181        ) {
182            registry.create_introspection_types();
183        }
184
185        // create entity types
186        if self.enable_federation || registry.has_entities() {
187            registry.enable_federation = true;
188            registry.create_federation_types();
189
190            // create _Entity type
191            let entity = self
192                .types
193                .values()
194                .filter(|ty| match ty {
195                    Type::Object(obj) => obj.is_entity(),
196                    Type::Interface(interface) => interface.is_entity(),
197                    _ => false,
198                })
199                .fold(Union::new("_Entity"), |entity, ty| {
200                    entity.possible_type(ty.name())
201                });
202            self.types
203                .insert("_Entity".to_string(), Type::Union(entity));
204        }
205
206        let inner = SchemaInner {
207            env: SchemaEnv(Arc::new(SchemaEnvInner {
208                registry,
209                data: self.data,
210                custom_directives: Default::default(),
211            })),
212            extensions: self.extensions,
213            types: self.types,
214            recursive_depth: self.recursive_depth,
215            max_directives: self.max_directives,
216            complexity: self.complexity,
217            depth: self.depth,
218            validation_mode: self.validation_mode,
219            entity_resolver: self.entity_resolver,
220        };
221        inner.check()?;
222        Ok(Schema(Arc::new(inner)))
223    }
224}
225
226/// Dynamic GraphQL schema.
227///
228/// Cloning a schema is cheap, so it can be easily shared.
229#[derive(Clone)]
230pub struct Schema(pub(crate) Arc<SchemaInner>);
231
232impl Debug for Schema {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        f.debug_struct("Schema").finish()
235    }
236}
237
238pub struct SchemaInner {
239    pub(crate) env: SchemaEnv,
240    pub(crate) types: IndexMap<String, Type>,
241    extensions: Vec<Box<dyn ExtensionFactory>>,
242    recursive_depth: usize,
243    max_directives: Option<usize>,
244    complexity: Option<usize>,
245    depth: Option<usize>,
246    validation_mode: ValidationMode,
247    pub(crate) entity_resolver: Option<BoxResolverFn>,
248}
249
250impl Schema {
251    /// Create a schema builder
252    pub fn build(query: &str, mutation: Option<&str>, subscription: Option<&str>) -> SchemaBuilder {
253        SchemaBuilder {
254            query_type: query.to_string(),
255            mutation_type: mutation.map(ToString::to_string),
256            subscription_type: subscription.map(ToString::to_string),
257            types: Default::default(),
258            data: Default::default(),
259            extensions: Default::default(),
260            validation_mode: ValidationMode::Strict,
261            recursive_depth: 32,
262            max_directives: None,
263            complexity: None,
264            depth: None,
265            enable_suggestions: true,
266            introspection_mode: IntrospectionMode::Enabled,
267            entity_resolver: None,
268            enable_federation: false,
269        }
270    }
271
272    fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
273        Extensions::new(
274            self.0.extensions.iter().map(|f| f.create()),
275            self.0.env.clone(),
276            session_data,
277        )
278    }
279
280    fn query_root(&self) -> ServerResult<&Object> {
281        self.0
282            .types
283            .get(&self.0.env.registry.query_type)
284            .and_then(Type::as_object)
285            .ok_or_else(|| ServerError::new("Query root not found", None))
286    }
287
288    fn mutation_root(&self) -> ServerResult<&Object> {
289        self.0
290            .env
291            .registry
292            .mutation_type
293            .as_ref()
294            .and_then(|mutation_name| self.0.types.get(mutation_name))
295            .and_then(Type::as_object)
296            .ok_or_else(|| ServerError::new("Mutation root not found", None))
297    }
298
299    fn subscription_root(&self) -> ServerResult<&Subscription> {
300        self.0
301            .env
302            .registry
303            .subscription_type
304            .as_ref()
305            .and_then(|subscription_name| self.0.types.get(subscription_name))
306            .and_then(Type::as_subscription)
307            .ok_or_else(|| ServerError::new("Subscription root not found", None))
308    }
309
310    /// Returns SDL(Schema Definition Language) of this schema.
311    pub fn sdl(&self) -> String {
312        self.0.env.registry.export_sdl(Default::default())
313    }
314
315    /// Returns SDL(Schema Definition Language) of this schema with options.
316    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
317        self.0.env.registry.export_sdl(options)
318    }
319
320    async fn execute_once(
321        &self,
322        env: QueryEnv,
323        root_value: &FieldValue<'static>,
324        execute_data: Option<Data>,
325    ) -> Response {
326        // execute
327        let ctx = env.create_context(
328            &self.0.env,
329            None,
330            &env.operation.node.selection_set,
331            execute_data.as_ref(),
332        );
333        let res = match &env.operation.node.ty {
334            OperationType::Query => {
335                async move { self.query_root() }
336                    .and_then(|query_root| {
337                        resolve_container(self, query_root, &ctx, root_value, false)
338                    })
339                    .await
340            }
341            OperationType::Mutation => {
342                async move { self.mutation_root() }
343                    .and_then(|query_root| {
344                        resolve_container(self, query_root, &ctx, root_value, true)
345                    })
346                    .await
347            }
348            OperationType::Subscription => Err(ServerError::new(
349                "Subscriptions are not supported on this transport.",
350                None,
351            )),
352        };
353
354        let mut resp = match res {
355            Ok(value) => Response::new(value.unwrap_or_default()),
356            Err(err) => Response::from_errors(vec![err]),
357        }
358        .http_headers(std::mem::take(&mut *env.http_headers.lock().unwrap()));
359
360        resp.errors
361            .extend(std::mem::take(&mut *env.errors.lock().unwrap()));
362        resp
363    }
364
365    /// Execute a GraphQL query.
366    pub async fn execute(&self, request: impl Into<DynamicRequest>) -> Response {
367        let request = request.into();
368        let extensions = self.create_extensions(Default::default());
369        let request_fut = {
370            let extensions = extensions.clone();
371            async move {
372                match prepare_request(
373                    extensions,
374                    request.inner,
375                    Default::default(),
376                    &self.0.env.registry,
377                    self.0.validation_mode,
378                    self.0.recursive_depth,
379                    self.0.max_directives,
380                    self.0.complexity,
381                    self.0.depth,
382                )
383                .await
384                {
385                    Ok((env, cache_control)) => {
386                        let f = {
387                            |execute_data| {
388                                let env = env.clone();
389                                async move {
390                                    self.execute_once(env, &request.root_value, execute_data)
391                                        .await
392                                        .cache_control(cache_control)
393                                }
394                            }
395                        };
396                        env.extensions
397                            .execute(env.operation_name.as_deref(), f)
398                            .await
399                    }
400                    Err(errors) => Response::from_errors(errors),
401                }
402            }
403        };
404        futures_util::pin_mut!(request_fut);
405        extensions.request(&mut request_fut).await
406    }
407
408    /// Execute a GraphQL subscription with session data.
409    pub fn execute_stream_with_session_data(
410        &self,
411        request: impl Into<DynamicRequest>,
412        session_data: Arc<Data>,
413    ) -> BoxStream<'static, Response> {
414        let schema = self.clone();
415        let request = request.into();
416        let extensions = self.create_extensions(session_data.clone());
417
418        let stream = {
419            let extensions = extensions.clone();
420
421            asynk_strim::stream_fn(|mut yielder| async move {
422                let subscription = match schema.subscription_root() {
423                    Ok(subscription) => subscription,
424                    Err(err) => {
425                        yielder.yield_item(Response::from_errors(vec![err])).await;
426                        return;
427                    }
428                };
429
430                let (env, _) = match prepare_request(
431                    extensions,
432                    request.inner,
433                    session_data,
434                    &schema.0.env.registry,
435                    schema.0.validation_mode,
436                    schema.0.recursive_depth,
437                    schema.0.max_directives,
438                    schema.0.complexity,
439                    schema.0.depth,
440                )
441                .await
442                {
443                    Ok(res) => res,
444                    Err(errors) => {
445                        yielder.yield_item(Response::from_errors(errors)).await;
446                        return;
447                    }
448                };
449
450                if env.operation.node.ty != OperationType::Subscription {
451                    yielder
452                        .yield_item(schema.execute_once(env, &request.root_value, None).await)
453                        .await;
454                    return;
455                }
456
457                let ctx = env.create_context(
458                    &schema.0.env,
459                    None,
460                    &env.operation.node.selection_set,
461                    None,
462                );
463                let mut streams = Vec::new();
464                subscription.collect_streams(&schema, &ctx, &mut streams, &request.root_value);
465
466                let mut stream = futures_util::stream::select_all(streams);
467                while let Some(resp) = stream.next().await {
468                    yielder.yield_item(resp).await;
469                }
470            })
471        };
472        extensions.subscribe(stream.boxed())
473    }
474
475    /// Execute a GraphQL subscription.
476    pub fn execute_stream(
477        &self,
478        request: impl Into<DynamicRequest>,
479    ) -> BoxStream<'static, Response> {
480        self.execute_stream_with_session_data(request, Default::default())
481    }
482
483    /// Returns the registry of this schema.
484    pub fn registry(&self) -> &Registry {
485        &self.0.env.registry
486    }
487}
488
489#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
490impl Executor for Schema {
491    async fn execute(&self, request: Request) -> Response {
492        Schema::execute(self, request).await
493    }
494
495    fn execute_stream(
496        &self,
497        request: Request,
498        session_data: Option<Arc<Data>>,
499    ) -> BoxStream<'static, Response> {
500        Schema::execute_stream_with_session_data(self, request, session_data.unwrap_or_default())
501    }
502}
503
504fn update_interface_possible_types(types: &mut IndexMap<String, Type>, registry: &mut Registry) {
505    let mut interfaces = registry
506        .types
507        .values_mut()
508        .filter_map(|ty| match ty {
509            MetaType::Interface {
510                name,
511                possible_types,
512                ..
513            } => Some((name, possible_types)),
514            _ => None,
515        })
516        .collect::<HashMap<_, _>>();
517
518    let objs = types.values().filter_map(|ty| match ty {
519        Type::Object(obj) => Some((&obj.name, &obj.implements)),
520        _ => None,
521    });
522
523    for (obj_name, implements) in objs {
524        for interface in implements {
525            if let Some(possible_types) = interfaces.get_mut(interface) {
526                possible_types.insert(obj_name.clone());
527            }
528        }
529    }
530}
531
532#[cfg(test)]
533mod tests {
534    use std::sync::Arc;
535
536    use async_graphql_parser::{Pos, types::ExecutableDocument};
537    use async_graphql_value::Variables;
538    use futures_util::{StreamExt, stream::BoxStream};
539    use tokio::sync::Mutex;
540
541    use crate::{
542        PathSegment, Request, Response, ServerError, ServerResult, ValidationResult, Value,
543        dynamic::*, extensions::*, value,
544    };
545
546    #[tokio::test]
547    async fn basic_query() {
548        let myobj = Object::new("MyObj")
549            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
550                FieldFuture::new(async { Ok(Some(Value::from(123))) })
551            }))
552            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
553                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
554            }));
555
556        let query = Object::new("Query")
557            .field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
558                FieldFuture::new(async { Ok(Some(Value::from(100))) })
559            }))
560            .field(Field::new(
561                "valueObj",
562                TypeRef::named_nn(myobj.type_name()),
563                |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
564            ));
565        let schema = Schema::build("Query", None, None)
566            .register(query)
567            .register(myobj)
568            .finish()
569            .unwrap();
570
571        assert_eq!(
572            schema
573                .execute("{ value valueObj { a b } }")
574                .await
575                .into_result()
576                .unwrap()
577                .data,
578            value!({
579                "value": 100,
580                "valueObj": {
581                    "a": 123,
582                    "b": "abc",
583                }
584            })
585        );
586    }
587
588    #[tokio::test]
589    async fn root_value() {
590        let query =
591            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |ctx| {
592                FieldFuture::new(async {
593                    Ok(Some(Value::Number(
594                        (*ctx.parent_value.try_downcast_ref::<i32>()?).into(),
595                    )))
596                })
597            }));
598
599        let schema = Schema::build("Query", None, None)
600            .register(query)
601            .finish()
602            .unwrap();
603        assert_eq!(
604            schema
605                .execute("{ value }".root_value(FieldValue::owned_any(100)))
606                .await
607                .into_result()
608                .unwrap()
609                .data,
610            value!({ "value": 100, })
611        );
612    }
613
614    #[tokio::test]
615    async fn field_alias() {
616        let query =
617            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
618                FieldFuture::new(async { Ok(Some(Value::from(100))) })
619            }));
620        let schema = Schema::build("Query", None, None)
621            .register(query)
622            .finish()
623            .unwrap();
624
625        assert_eq!(
626            schema
627                .execute("{ a: value }")
628                .await
629                .into_result()
630                .unwrap()
631                .data,
632            value!({
633                "a": 100,
634            })
635        );
636    }
637
638    #[tokio::test]
639    async fn fragment_spread() {
640        let myobj = Object::new("MyObj")
641            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
642                FieldFuture::new(async { Ok(Some(Value::from(123))) })
643            }))
644            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
645                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
646            }));
647
648        let query = Object::new("Query").field(Field::new(
649            "valueObj",
650            TypeRef::named_nn(myobj.type_name()),
651            |_| FieldFuture::new(async { Ok(Some(Value::Null)) }),
652        ));
653        let schema = Schema::build("Query", None, None)
654            .register(query)
655            .register(myobj)
656            .finish()
657            .unwrap();
658
659        let query = r#"
660            fragment A on MyObj {
661                a b
662            }
663
664            { valueObj { ... A } }
665            "#;
666
667        assert_eq!(
668            schema.execute(query).await.into_result().unwrap().data,
669            value!({
670                "valueObj": {
671                    "a": 123,
672                    "b": "abc",
673                }
674            })
675        );
676    }
677
678    #[tokio::test]
679    async fn inline_fragment() {
680        let myobj = Object::new("MyObj")
681            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
682                FieldFuture::new(async { Ok(Some(Value::from(123))) })
683            }))
684            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
685                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
686            }));
687
688        let query = Object::new("Query").field(Field::new(
689            "valueObj",
690            TypeRef::named_nn(myobj.type_name()),
691            |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
692        ));
693        let schema = Schema::build("Query", None, None)
694            .register(query)
695            .register(myobj)
696            .finish()
697            .unwrap();
698
699        let query = r#"
700            {
701                valueObj {
702                     ... on MyObj { a }
703                     ... { b }
704                }
705            }
706            "#;
707
708        assert_eq!(
709            schema.execute(query).await.into_result().unwrap().data,
710            value!({
711                "valueObj": {
712                    "a": 123,
713                    "b": "abc",
714                }
715            })
716        );
717    }
718
719    #[tokio::test]
720    async fn non_null() {
721        let query = Object::new("Query")
722            .field(Field::new(
723                "valueA",
724                TypeRef::named_nn(TypeRef::INT),
725                |_| FieldFuture::new(async { Ok(FieldValue::none()) }),
726            ))
727            .field(Field::new(
728                "valueB",
729                TypeRef::named_nn(TypeRef::INT),
730                |_| FieldFuture::new(async { Ok(Some(Value::from(100))) }),
731            ))
732            .field(Field::new("valueC", TypeRef::named(TypeRef::INT), |_| {
733                FieldFuture::new(async { Ok(FieldValue::none()) })
734            }))
735            .field(Field::new("valueD", TypeRef::named(TypeRef::INT), |_| {
736                FieldFuture::new(async { Ok(Some(Value::from(200))) })
737            }));
738        let schema = Schema::build("Query", None, None)
739            .register(query)
740            .finish()
741            .unwrap();
742
743        assert_eq!(
744            schema
745                .execute("{ valueA }")
746                .await
747                .into_result()
748                .unwrap_err(),
749            vec![ServerError {
750                message: "internal: non-null types require a return value".to_owned(),
751                source: None,
752                locations: vec![Pos { column: 3, line: 1 }],
753                path: vec![PathSegment::Field("valueA".to_owned())],
754                extensions: None,
755            }]
756        );
757
758        assert_eq!(
759            schema
760                .execute("{ valueB }")
761                .await
762                .into_result()
763                .unwrap()
764                .data,
765            value!({
766                "valueB": 100
767            })
768        );
769
770        assert_eq!(
771            schema
772                .execute("{ valueC valueD }")
773                .await
774                .into_result()
775                .unwrap()
776                .data,
777            value!({
778                "valueC": null,
779                "valueD": 200,
780            })
781        );
782    }
783
784    #[tokio::test]
785    async fn list() {
786        let query = Object::new("Query")
787            .field(Field::new(
788                "values",
789                TypeRef::named_nn_list_nn(TypeRef::INT),
790                |_| {
791                    FieldFuture::new(async {
792                        Ok(Some(vec![Value::from(3), Value::from(6), Value::from(9)]))
793                    })
794                },
795            ))
796            .field(Field::new(
797                "values2",
798                TypeRef::named_nn_list_nn(TypeRef::INT),
799                |_| {
800                    FieldFuture::new(async {
801                        Ok(Some(Value::List(vec![
802                            Value::from(3),
803                            Value::from(6),
804                            Value::from(9),
805                        ])))
806                    })
807                },
808            ))
809            .field(Field::new(
810                "values3",
811                TypeRef::named_nn_list(TypeRef::INT),
812                |_| FieldFuture::new(async { Ok(None::<Vec<Value>>) }),
813            ));
814        let schema = Schema::build("Query", None, None)
815            .register(query)
816            .finish()
817            .unwrap();
818
819        assert_eq!(
820            schema
821                .execute("{ values values2 values3 }")
822                .await
823                .into_result()
824                .unwrap()
825                .data,
826            value!({
827                "values": [3, 6, 9],
828                "values2": [3, 6, 9],
829                "values3": null,
830            })
831        );
832    }
833
834    #[tokio::test]
835    async fn extensions() {
836        struct MyExtensionImpl {
837            calls: Arc<Mutex<Vec<&'static str>>>,
838        }
839
840        #[async_trait::async_trait]
841        #[allow(unused_variables)]
842        impl Extension for MyExtensionImpl {
843            async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
844                self.calls.lock().await.push("request_start");
845                let res = next.run(ctx).await;
846                self.calls.lock().await.push("request_end");
847                res
848            }
849
850            fn subscribe<'s>(
851                &self,
852                ctx: &ExtensionContext<'_>,
853                mut stream: BoxStream<'s, Response>,
854                next: NextSubscribe<'_>,
855            ) -> BoxStream<'s, Response> {
856                let calls = self.calls.clone();
857                next.run(
858                    ctx,
859                    Box::pin(asynk_strim::stream_fn(|mut yielder| async move {
860                        calls.lock().await.push("subscribe_start");
861                        while let Some(item) = stream.next().await {
862                            yielder.yield_item(item).await;
863                        }
864                        calls.lock().await.push("subscribe_end");
865                    })),
866                )
867            }
868
869            async fn prepare_request(
870                &self,
871                ctx: &ExtensionContext<'_>,
872                request: Request,
873                next: NextPrepareRequest<'_>,
874            ) -> ServerResult<Request> {
875                self.calls.lock().await.push("prepare_request_start");
876                let res = next.run(ctx, request).await;
877                self.calls.lock().await.push("prepare_request_end");
878                res
879            }
880
881            async fn parse_query(
882                &self,
883                ctx: &ExtensionContext<'_>,
884                query: &str,
885                variables: &Variables,
886                next: NextParseQuery<'_>,
887            ) -> ServerResult<ExecutableDocument> {
888                self.calls.lock().await.push("parse_query_start");
889                let res = next.run(ctx, query, variables).await;
890                self.calls.lock().await.push("parse_query_end");
891                res
892            }
893
894            async fn validation(
895                &self,
896                ctx: &ExtensionContext<'_>,
897                next: NextValidation<'_>,
898            ) -> Result<ValidationResult, Vec<ServerError>> {
899                self.calls.lock().await.push("validation_start");
900                let res = next.run(ctx).await;
901                self.calls.lock().await.push("validation_end");
902                res
903            }
904
905            async fn execute(
906                &self,
907                ctx: &ExtensionContext<'_>,
908                operation_name: Option<&str>,
909                next: NextExecute<'_>,
910            ) -> Response {
911                assert_eq!(operation_name, Some("Abc"));
912                self.calls.lock().await.push("execute_start");
913                let res = next.run(ctx, operation_name).await;
914                self.calls.lock().await.push("execute_end");
915                res
916            }
917
918            async fn resolve(
919                &self,
920                ctx: &ExtensionContext<'_>,
921                info: ResolveInfo<'_>,
922                next: NextResolve<'_>,
923            ) -> ServerResult<Option<Value>> {
924                self.calls.lock().await.push("resolve_start");
925                let res = next.run(ctx, info).await;
926                self.calls.lock().await.push("resolve_end");
927                res
928            }
929        }
930
931        struct MyExtension {
932            calls: Arc<Mutex<Vec<&'static str>>>,
933        }
934
935        impl ExtensionFactory for MyExtension {
936            fn create(&self) -> Arc<dyn Extension> {
937                Arc::new(MyExtensionImpl {
938                    calls: self.calls.clone(),
939                })
940            }
941        }
942
943        {
944            let query = Object::new("Query")
945                .field(Field::new(
946                    "value1",
947                    TypeRef::named_nn(TypeRef::INT),
948                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
949                ))
950                .field(Field::new(
951                    "value2",
952                    TypeRef::named_nn(TypeRef::INT),
953                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
954                ));
955
956            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
957            let schema = Schema::build(query.type_name(), None, None)
958                .register(query)
959                .extension(MyExtension {
960                    calls: calls.clone(),
961                })
962                .finish()
963                .unwrap();
964
965            let _ = schema
966                .execute("query Abc { value1 value2 }")
967                .await
968                .into_result()
969                .unwrap();
970            let calls = calls.lock().await;
971            assert_eq!(
972                &*calls,
973                &vec![
974                    "request_start",
975                    "prepare_request_start",
976                    "prepare_request_end",
977                    "parse_query_start",
978                    "parse_query_end",
979                    "validation_start",
980                    "validation_end",
981                    "execute_start",
982                    "resolve_start",
983                    "resolve_end",
984                    "resolve_start",
985                    "resolve_end",
986                    "execute_end",
987                    "request_end",
988                ]
989            );
990        }
991
992        {
993            let query = Object::new("Query").field(Field::new(
994                "value1",
995                TypeRef::named_nn(TypeRef::INT),
996                |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
997            ));
998
999            let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
1000                "value",
1001                TypeRef::named_nn(TypeRef::INT),
1002                |_| {
1003                    SubscriptionFieldFuture::new(async {
1004                        Ok(futures_util::stream::iter([1, 2, 3])
1005                            .map(|value| Ok(Value::from(value))))
1006                    })
1007                },
1008            ));
1009
1010            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
1011            let schema = Schema::build(query.type_name(), None, Some(subscription.type_name()))
1012                .register(query)
1013                .register(subscription)
1014                .extension(MyExtension {
1015                    calls: calls.clone(),
1016                })
1017                .finish()
1018                .unwrap();
1019
1020            let mut stream = schema.execute_stream("subscription Abc { value }");
1021            while stream.next().await.is_some() {}
1022            let calls = calls.lock().await;
1023            assert_eq!(
1024                &*calls,
1025                &vec![
1026                    "subscribe_start",
1027                    "prepare_request_start",
1028                    "prepare_request_end",
1029                    "parse_query_start",
1030                    "parse_query_end",
1031                    "validation_start",
1032                    "validation_end",
1033                    // push 1
1034                    "execute_start",
1035                    "resolve_start",
1036                    "resolve_end",
1037                    "execute_end",
1038                    // push 2
1039                    "execute_start",
1040                    "resolve_start",
1041                    "resolve_end",
1042                    "execute_end",
1043                    // push 3
1044                    "execute_start",
1045                    "resolve_start",
1046                    "resolve_end",
1047                    "execute_end",
1048                    // end
1049                    "subscribe_end",
1050                ]
1051            );
1052        }
1053    }
1054
1055    #[tokio::test]
1056    async fn federation() {
1057        let user = Object::new("User")
1058            .field(Field::new(
1059                "name",
1060                TypeRef::named_nn(TypeRef::STRING),
1061                |_| FieldFuture::new(async { Ok(Some(FieldValue::value("test"))) }),
1062            ))
1063            .key("name");
1064
1065        let query =
1066            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
1067                FieldFuture::new(async { Ok(Some(Value::from(100))) })
1068            }));
1069
1070        let schema = Schema::build("Query", None, None)
1071            .register(query)
1072            .register(user)
1073            .entity_resolver(|ctx| {
1074                FieldFuture::new(async move {
1075                    let representations = ctx.args.try_get("representations")?.list()?;
1076                    let mut values = Vec::new();
1077
1078                    for item in representations.iter() {
1079                        let item = item.object()?;
1080                        let typename = item
1081                            .try_get("__typename")
1082                            .and_then(|value| value.string())?;
1083
1084                        if typename == "User" {
1085                            values.push(FieldValue::borrowed_any(&()).with_type("User"));
1086                        }
1087                    }
1088
1089                    Ok(Some(FieldValue::list(values)))
1090                })
1091            })
1092            .finish()
1093            .unwrap();
1094
1095        assert_eq!(
1096            schema
1097                .execute(
1098                    r#"
1099                {
1100                    _entities(representations: [{__typename: "User", name: "test"}]) {
1101                        __typename
1102                        ... on User {
1103                            name
1104                        }
1105                    }
1106                }
1107                "#
1108                )
1109                .await
1110                .into_result()
1111                .unwrap()
1112                .data,
1113            value!({
1114                "_entities": [{
1115                    "__typename": "User",
1116                    "name": "test",
1117                }],
1118            })
1119        );
1120    }
1121}