async_graphql/dynamic/
check.rs

1use std::collections::HashSet;
2
3use indexmap::IndexMap;
4
5use crate::dynamic::{
6    InputObject, Interface, Object, SchemaError, Type,
7    base::{BaseContainer, BaseField},
8    schema::SchemaInner,
9    type_ref::TypeRef,
10};
11
12impl SchemaInner {
13    pub(crate) fn check(&self) -> Result<(), SchemaError> {
14        self.check_types_exists()?;
15        self.check_root_types()?;
16        self.check_objects()?;
17        self.check_input_objects()?;
18        self.check_interfaces()?;
19        self.check_unions()?;
20        Ok(())
21    }
22
23    fn check_root_types(&self) -> Result<(), SchemaError> {
24        if let Some(ty) = self.types.get(&self.env.registry.query_type) {
25            if !matches!(ty, Type::Object(_)) {
26                return Err("The query root must be an object".into());
27            }
28        }
29
30        if let Some(mutation_type) = &self.env.registry.mutation_type {
31            if let Some(ty) = self.types.get(mutation_type) {
32                if !matches!(ty, Type::Object(_)) {
33                    return Err("The mutation root must be an object".into());
34                }
35            }
36        }
37
38        if let Some(subscription_type) = &self.env.registry.subscription_type {
39            if let Some(ty) = self.types.get(subscription_type) {
40                if !matches!(ty, Type::Subscription(_)) {
41                    return Err("The subscription root must be a subscription object".into());
42                }
43            }
44        }
45
46        Ok(())
47    }
48
49    fn check_types_exists(&self) -> Result<(), SchemaError> {
50        fn check<I: IntoIterator<Item = T>, T: AsRef<str>>(
51            types: &IndexMap<String, Type>,
52            type_names: I,
53        ) -> Result<(), SchemaError> {
54            for name in type_names {
55                if !types.contains_key(name.as_ref()) {
56                    return Err(format!("Type \"{0}\" not found", name.as_ref()).into());
57                }
58            }
59            Ok(())
60        }
61
62        check(
63            &self.types,
64            std::iter::once(self.env.registry.query_type.as_str())
65                .chain(self.env.registry.mutation_type.as_deref()),
66        )?;
67
68        for ty in self.types.values() {
69            match ty {
70                Type::Object(obj) => check(
71                    &self.types,
72                    obj.fields
73                        .values()
74                        .map(|field| {
75                            std::iter::once(field.ty.type_name())
76                                .chain(field.arguments.values().map(|arg| arg.ty.type_name()))
77                        })
78                        .flatten()
79                        .chain(obj.implements.iter().map(AsRef::as_ref)),
80                )?,
81                Type::InputObject(obj) => {
82                    check(
83                        &self.types,
84                        obj.fields.values().map(|field| field.ty.type_name()),
85                    )?;
86                }
87                Type::Interface(interface) => check(
88                    &self.types,
89                    interface
90                        .fields
91                        .values()
92                        .map(|field| {
93                            std::iter::once(field.ty.type_name())
94                                .chain(field.arguments.values().map(|arg| arg.ty.type_name()))
95                        })
96                        .flatten(),
97                )?,
98                Type::Union(union) => check(&self.types, &union.possible_types)?,
99                Type::Subscription(subscription) => check(
100                    &self.types,
101                    subscription
102                        .fields
103                        .values()
104                        .map(|field| {
105                            std::iter::once(field.ty.type_name())
106                                .chain(field.arguments.values().map(|arg| arg.ty.type_name()))
107                        })
108                        .flatten(),
109                )?,
110                Type::Scalar(_) | Type::Enum(_) | Type::Upload => {}
111            }
112        }
113
114        Ok(())
115    }
116
117    fn check_objects(&self) -> Result<(), SchemaError> {
118        let has_entities = self
119            .types
120            .iter()
121            .filter_map(|(_, ty)| ty.as_object())
122            .any(Object::is_entity);
123
124        // https://spec.graphql.org/October2021/#sec-Objects.Type-Validation
125        for ty in self.types.values() {
126            if let Type::Object(obj) = ty {
127                // An Object type must define one or more fields.
128                if obj.fields.is_empty()
129                    && !(obj.type_name() == self.env.registry.query_type && has_entities)
130                {
131                    return Err(
132                        format!("Object \"{}\" must define one or more fields", obj.name).into(),
133                    );
134                }
135
136                for field in obj.fields.values() {
137                    // The field must not have a name which begins with the characters "__" (two
138                    // underscores)
139                    if field.name.starts_with("__") {
140                        return Err(format!("Field \"{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", obj.name, field.name).into());
141                    }
142
143                    // The field must return a type where IsOutputType(fieldType) returns true.
144                    if let Some(ty) = self.types.get(field.ty.type_name()) {
145                        if !ty.is_output_type() {
146                            return Err(format!(
147                                "Field \"{}.{}\" must return a output type",
148                                obj.name, field.name
149                            )
150                            .into());
151                        }
152                    }
153
154                    for arg in field.arguments.values() {
155                        // The argument must not have a name which begins with the characters "__"
156                        // (two underscores).
157                        if arg.name.starts_with("__") {
158                            return Err(format!("Argument \"{}.{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", obj.name, field.name, arg.name).into());
159                        }
160
161                        // The argument must accept a type where
162                        // IsInputType(argumentType) returns true.
163                        if let Some(ty) = self.types.get(arg.ty.type_name()) {
164                            if !ty.is_input_type() {
165                                return Err(format!(
166                                    "Argument \"{}.{}.{}\" must accept a input type",
167                                    obj.name, field.name, arg.name
168                                )
169                                .into());
170                            }
171                        }
172                    }
173                }
174
175                for interface_name in &obj.implements {
176                    if let Some(ty) = self.types.get(interface_name) {
177                        let interface = ty.as_interface().ok_or_else(|| {
178                            format!("Type \"{}\" is not interface", interface_name)
179                        })?;
180                        check_is_valid_implementation(obj, interface)?;
181                    }
182                }
183            }
184        }
185
186        Ok(())
187    }
188
189    fn check_input_objects(&self) -> Result<(), SchemaError> {
190        // https://spec.graphql.org/October2021/#sec-Input-Objects.Type-Validation
191        for ty in self.types.values() {
192            if let Type::InputObject(obj) = ty {
193                for field in obj.fields.values() {
194                    // The field must not have a name which begins with the characters "__" (two
195                    // underscores)
196                    if field.name.starts_with("__") {
197                        return Err(format!("Field \"{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", obj.name, field.name).into());
198                    }
199
200                    // The input field must accept a type where IsInputType(inputFieldType) returns
201                    // true.
202                    if let Some(ty) = self.types.get(field.ty.type_name()) {
203                        if !ty.is_input_type() {
204                            return Err(format!(
205                                "Field \"{}.{}\" must accept a input type",
206                                obj.name, field.name
207                            )
208                            .into());
209                        }
210                    }
211
212                    if obj.oneof {
213                        // The type of the input field must be nullable.
214                        if !field.ty.is_nullable() {
215                            return Err(format!(
216                                "Field \"{}.{}\" must be nullable",
217                                obj.name, field.name
218                            )
219                            .into());
220                        }
221
222                        // The input field must not have a default value.
223                        if field.default_value.is_some() {
224                            return Err(format!(
225                                "Field \"{}.{}\" must not have a default value",
226                                obj.name, field.name
227                            )
228                            .into());
229                        }
230                    }
231                }
232
233                // If an Input Object references itself either directly or
234                // through referenced Input Objects, at least one of the
235                // fields in the chain of references must be either a
236                // nullable or a List type.
237                self.check_input_object_reference(&obj.name, &obj, &mut HashSet::new())?;
238            }
239        }
240
241        Ok(())
242    }
243
244    fn check_input_object_reference<'a>(
245        &'a self,
246        current: &str,
247        obj: &'a InputObject,
248        ref_chain: &mut HashSet<&'a str>,
249    ) -> Result<(), SchemaError> {
250        fn typeref_nonnullable_name(ty: &TypeRef) -> Option<&str> {
251            match ty {
252                TypeRef::NonNull(inner) => match inner.as_ref() {
253                    TypeRef::Named(name) => Some(name),
254                    _ => None,
255                },
256                _ => None,
257            }
258        }
259
260        for field in obj.fields.values() {
261            if let Some(this_name) = typeref_nonnullable_name(&field.ty) {
262                if this_name == current {
263                    return Err(format!("\"{}\" references itself either directly or through referenced Input Objects, at least one of the fields in the chain of references must be either a nullable or a List type.", current).into());
264                } else if let Some(obj) = self
265                    .types
266                    .get(field.ty.type_name())
267                    .and_then(Type::as_input_object)
268                {
269                    // don't visit the reference if we've already visited it in this call chain
270                    //  (prevents getting stuck in local cycles and overflowing stack)
271                    //  true return from insert indicates the value was not previously there
272                    if ref_chain.insert(this_name) {
273                        self.check_input_object_reference(current, obj, ref_chain)?;
274                        ref_chain.remove(this_name);
275                    }
276                }
277            }
278        }
279
280        Ok(())
281    }
282
283    fn check_interfaces(&self) -> Result<(), SchemaError> {
284        // https://spec.graphql.org/October2021/#sec-Interfaces.Type-Validation
285        for ty in self.types.values() {
286            if let Type::Interface(interface) = ty {
287                for field in interface.fields.values() {
288                    // The field must not have a name which begins with the characters "__" (two
289                    // underscores)
290                    if field.name.starts_with("__") {
291                        return Err(format!("Field \"{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", interface.name, field.name).into());
292                    }
293
294                    // The field must return a type where IsOutputType(fieldType) returns true.
295                    if let Some(ty) = self.types.get(field.ty.type_name()) {
296                        if !ty.is_output_type() {
297                            return Err(format!(
298                                "Field \"{}.{}\" must return a output type",
299                                interface.name, field.name
300                            )
301                            .into());
302                        }
303                    }
304
305                    for arg in field.arguments.values() {
306                        // The argument must not have a name which begins with the characters "__"
307                        // (two underscores).
308                        if arg.name.starts_with("__") {
309                            return Err(format!("Argument \"{}.{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", interface.name, field.name, arg.name).into());
310                        }
311
312                        // The argument must accept a type where
313                        // IsInputType(argumentType) returns true.
314                        if let Some(ty) = self.types.get(arg.ty.type_name()) {
315                            if !ty.is_input_type() {
316                                return Err(format!(
317                                    "Argument \"{}.{}.{}\" must accept a input type",
318                                    interface.name, field.name, arg.name
319                                )
320                                .into());
321                            }
322                        }
323                    }
324
325                    // An interface type may declare that it implements one or more unique
326                    // interfaces, but may not implement itself.
327                    if interface.implements.contains(&interface.name) {
328                        return Err(format!(
329                            "Interface \"{}\" may not implement itself",
330                            interface.name
331                        )
332                        .into());
333                    }
334
335                    // An interface type must be a super-set of all interfaces
336                    // it implements
337                    for interface_name in &interface.implements {
338                        if let Some(ty) = self.types.get(interface_name) {
339                            let implemenented_type = ty.as_interface().ok_or_else(|| {
340                                format!("Type \"{}\" is not interface", interface_name)
341                            })?;
342                            check_is_valid_implementation(interface, implemenented_type)?;
343                        }
344                    }
345                }
346            }
347        }
348
349        Ok(())
350    }
351
352    fn check_unions(&self) -> Result<(), SchemaError> {
353        // https://spec.graphql.org/October2021/#sec-Unions.Type-Validation
354        for ty in self.types.values() {
355            if let Type::Union(union) = ty {
356                // The member types of a Union type must all be Object base
357                // types; Scalar, Interface and Union types must not be member
358                // types of a Union. Similarly, wrapping types must not be
359                // member types of a Union.
360                for type_name in &union.possible_types {
361                    if let Some(ty) = self.types.get(type_name) {
362                        if ty.as_object().is_none() {
363                            return Err(format!(
364                                "Member \"{}\" of union \"{}\" is not an object",
365                                type_name, union.name
366                            )
367                            .into());
368                        }
369                    }
370                }
371            }
372        }
373
374        Ok(())
375    }
376}
377
378fn check_is_valid_implementation(
379    implementing_type: &impl BaseContainer,
380    implemented_type: &Interface,
381) -> Result<(), SchemaError> {
382    for field in implemented_type.fields.values() {
383        let impl_field = implementing_type.field(&field.name).ok_or_else(|| {
384            format!(
385                "{} \"{}\" requires field \"{}\" defined by interface \"{}\"",
386                implementing_type.graphql_type(),
387                implementing_type.name(),
388                field.name,
389                implemented_type.name
390            )
391        })?;
392
393        for arg in field.arguments.values() {
394            let impl_arg = match impl_field.argument(&arg.name) {
395                Some(impl_arg) => impl_arg,
396                None if !arg.ty.is_nullable() => {
397                    return Err(format!(
398                        "Field \"{}.{}\" requires argument \"{}\" defined by interface \"{}.{}\"",
399                        implementing_type.name(),
400                        field.name,
401                        arg.name,
402                        implemented_type.name,
403                        field.name,
404                    )
405                    .into());
406                }
407                None => continue,
408            };
409
410            if !arg.ty.is_subtype(&impl_arg.ty) {
411                return Err(format!(
412                    "Argument \"{}.{}.{}\" is not sub-type of \"{}.{}.{}\"",
413                    implemented_type.name,
414                    field.name,
415                    arg.name,
416                    implementing_type.name(),
417                    field.name,
418                    arg.name
419                )
420                .into());
421            }
422        }
423
424        // field must return a type which is equal to or a sub-type of (covariant) the
425        // return type of implementedField field’s return type
426        if !impl_field.ty().is_subtype(&field.ty) {
427            return Err(format!(
428                "Field \"{}.{}\" is not sub-type of \"{}.{}\"",
429                implementing_type.name(),
430                field.name,
431                implemented_type.name,
432                field.name,
433            )
434            .into());
435        }
436    }
437
438    Ok(())
439}
440
441#[cfg(test)]
442mod tests {
443    use crate::{
444        Value,
445        dynamic::{
446            Field, FieldFuture, InputObject, InputValue, Object, Schema, SchemaBuilder, TypeRef,
447        },
448    };
449
450    fn base_schema() -> SchemaBuilder {
451        let query = Object::new("Query").field(Field::new("dummy", TypeRef::named("Int"), |_| {
452            FieldFuture::new(async { Ok(Some(Value::from(42))) })
453        }));
454        Schema::build("Query", None, None).register(query)
455    }
456
457    #[test]
458    fn test_recursive_input_objects() {
459        let top_level = InputObject::new("TopLevel")
460            .field(InputValue::new("mid", TypeRef::named_nn("MidLevel")));
461        let mid_level = InputObject::new("MidLevel")
462            .field(InputValue::new("bottom", TypeRef::named("BotLevel")))
463            .field(InputValue::new(
464                "list_bottom",
465                TypeRef::named_nn_list_nn("BotLevel"),
466            ));
467        let bot_level = InputObject::new("BotLevel")
468            .field(InputValue::new("top", TypeRef::named_nn("TopLevel")));
469        let schema = base_schema()
470            .register(top_level)
471            .register(mid_level)
472            .register(bot_level);
473        schema.finish().unwrap();
474    }
475
476    #[test]
477    fn test_recursive_input_objects_bad() {
478        let top_level = InputObject::new("TopLevel")
479            .field(InputValue::new("mid", TypeRef::named_nn("MidLevel")));
480        let mid_level = InputObject::new("MidLevel")
481            .field(InputValue::new("bottom", TypeRef::named_nn("BotLevel")));
482        let bot_level = InputObject::new("BotLevel")
483            .field(InputValue::new("top", TypeRef::named_nn("TopLevel")));
484        let schema = base_schema()
485            .register(top_level)
486            .register(mid_level)
487            .register(bot_level);
488        schema.finish().unwrap_err();
489    }
490
491    #[test]
492    fn test_recursive_input_objects_local_cycle() {
493        let top_level = InputObject::new("TopLevel")
494            .field(InputValue::new("mid", TypeRef::named_nn("MidLevel")));
495        let mid_level = InputObject::new("MidLevel")
496            .field(InputValue::new("bottom", TypeRef::named_nn("BotLevel")));
497        let bot_level = InputObject::new("BotLevel")
498            .field(InputValue::new("mid", TypeRef::named_nn("MidLevel")));
499        let schema = base_schema()
500            .register(top_level)
501            .register(mid_level)
502            .register(bot_level);
503        schema.finish().unwrap_err();
504    }
505}