1use std::collections::HashSet;
2
3use indexmap::IndexMap;
4
5use crate::dynamic::{
6 InputObject, Interface, Object, SchemaError, Type,
7 base::{BaseContainer, BaseField},
8 schema::SchemaInner,
9 type_ref::TypeRef,
10};
11
12impl SchemaInner {
13 pub(crate) fn check(&self) -> Result<(), SchemaError> {
14 self.check_types_exists()?;
15 self.check_root_types()?;
16 self.check_objects()?;
17 self.check_input_objects()?;
18 self.check_interfaces()?;
19 self.check_unions()?;
20 Ok(())
21 }
22
23 fn check_root_types(&self) -> Result<(), SchemaError> {
24 if let Some(ty) = self.types.get(&self.env.registry.query_type) {
25 if !matches!(ty, Type::Object(_)) {
26 return Err("The query root must be an object".into());
27 }
28 }
29
30 if let Some(mutation_type) = &self.env.registry.mutation_type {
31 if let Some(ty) = self.types.get(mutation_type) {
32 if !matches!(ty, Type::Object(_)) {
33 return Err("The mutation root must be an object".into());
34 }
35 }
36 }
37
38 if let Some(subscription_type) = &self.env.registry.subscription_type {
39 if let Some(ty) = self.types.get(subscription_type) {
40 if !matches!(ty, Type::Subscription(_)) {
41 return Err("The subscription root must be a subscription object".into());
42 }
43 }
44 }
45
46 Ok(())
47 }
48
49 fn check_types_exists(&self) -> Result<(), SchemaError> {
50 fn check<I: IntoIterator<Item = T>, T: AsRef<str>>(
51 types: &IndexMap<String, Type>,
52 type_names: I,
53 ) -> Result<(), SchemaError> {
54 for name in type_names {
55 if !types.contains_key(name.as_ref()) {
56 return Err(format!("Type \"{0}\" not found", name.as_ref()).into());
57 }
58 }
59 Ok(())
60 }
61
62 check(
63 &self.types,
64 std::iter::once(self.env.registry.query_type.as_str())
65 .chain(self.env.registry.mutation_type.as_deref()),
66 )?;
67
68 for ty in self.types.values() {
69 match ty {
70 Type::Object(obj) => check(
71 &self.types,
72 obj.fields
73 .values()
74 .map(|field| {
75 std::iter::once(field.ty.type_name())
76 .chain(field.arguments.values().map(|arg| arg.ty.type_name()))
77 })
78 .flatten()
79 .chain(obj.implements.iter().map(AsRef::as_ref)),
80 )?,
81 Type::InputObject(obj) => {
82 check(
83 &self.types,
84 obj.fields.values().map(|field| field.ty.type_name()),
85 )?;
86 }
87 Type::Interface(interface) => check(
88 &self.types,
89 interface
90 .fields
91 .values()
92 .map(|field| {
93 std::iter::once(field.ty.type_name())
94 .chain(field.arguments.values().map(|arg| arg.ty.type_name()))
95 })
96 .flatten(),
97 )?,
98 Type::Union(union) => check(&self.types, &union.possible_types)?,
99 Type::Subscription(subscription) => check(
100 &self.types,
101 subscription
102 .fields
103 .values()
104 .map(|field| {
105 std::iter::once(field.ty.type_name())
106 .chain(field.arguments.values().map(|arg| arg.ty.type_name()))
107 })
108 .flatten(),
109 )?,
110 Type::Scalar(_) | Type::Enum(_) | Type::Upload => {}
111 }
112 }
113
114 Ok(())
115 }
116
117 fn check_objects(&self) -> Result<(), SchemaError> {
118 let has_entities = self
119 .types
120 .iter()
121 .filter_map(|(_, ty)| ty.as_object())
122 .any(Object::is_entity);
123
124 for ty in self.types.values() {
126 if let Type::Object(obj) = ty {
127 if obj.fields.is_empty()
129 && !(obj.type_name() == self.env.registry.query_type && has_entities)
130 {
131 return Err(
132 format!("Object \"{}\" must define one or more fields", obj.name).into(),
133 );
134 }
135
136 for field in obj.fields.values() {
137 if field.name.starts_with("__") {
140 return Err(format!("Field \"{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", obj.name, field.name).into());
141 }
142
143 if let Some(ty) = self.types.get(field.ty.type_name()) {
145 if !ty.is_output_type() {
146 return Err(format!(
147 "Field \"{}.{}\" must return a output type",
148 obj.name, field.name
149 )
150 .into());
151 }
152 }
153
154 for arg in field.arguments.values() {
155 if arg.name.starts_with("__") {
158 return Err(format!("Argument \"{}.{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", obj.name, field.name, arg.name).into());
159 }
160
161 if let Some(ty) = self.types.get(arg.ty.type_name()) {
164 if !ty.is_input_type() {
165 return Err(format!(
166 "Argument \"{}.{}.{}\" must accept a input type",
167 obj.name, field.name, arg.name
168 )
169 .into());
170 }
171 }
172 }
173 }
174
175 for interface_name in &obj.implements {
176 if let Some(ty) = self.types.get(interface_name) {
177 let interface = ty.as_interface().ok_or_else(|| {
178 format!("Type \"{}\" is not interface", interface_name)
179 })?;
180 check_is_valid_implementation(obj, interface)?;
181 }
182 }
183 }
184 }
185
186 Ok(())
187 }
188
189 fn check_input_objects(&self) -> Result<(), SchemaError> {
190 for ty in self.types.values() {
192 if let Type::InputObject(obj) = ty {
193 for field in obj.fields.values() {
194 if field.name.starts_with("__") {
197 return Err(format!("Field \"{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", obj.name, field.name).into());
198 }
199
200 if let Some(ty) = self.types.get(field.ty.type_name()) {
203 if !ty.is_input_type() {
204 return Err(format!(
205 "Field \"{}.{}\" must accept a input type",
206 obj.name, field.name
207 )
208 .into());
209 }
210 }
211
212 if obj.oneof {
213 if !field.ty.is_nullable() {
215 return Err(format!(
216 "Field \"{}.{}\" must be nullable",
217 obj.name, field.name
218 )
219 .into());
220 }
221
222 if field.default_value.is_some() {
224 return Err(format!(
225 "Field \"{}.{}\" must not have a default value",
226 obj.name, field.name
227 )
228 .into());
229 }
230 }
231 }
232
233 self.check_input_object_reference(&obj.name, &obj, &mut HashSet::new())?;
238 }
239 }
240
241 Ok(())
242 }
243
244 fn check_input_object_reference<'a>(
245 &'a self,
246 current: &str,
247 obj: &'a InputObject,
248 ref_chain: &mut HashSet<&'a str>,
249 ) -> Result<(), SchemaError> {
250 fn typeref_nonnullable_name(ty: &TypeRef) -> Option<&str> {
251 match ty {
252 TypeRef::NonNull(inner) => match inner.as_ref() {
253 TypeRef::Named(name) => Some(name),
254 _ => None,
255 },
256 _ => None,
257 }
258 }
259
260 for field in obj.fields.values() {
261 if let Some(this_name) = typeref_nonnullable_name(&field.ty) {
262 if this_name == current {
263 return Err(format!("\"{}\" references itself either directly or through referenced Input Objects, at least one of the fields in the chain of references must be either a nullable or a List type.", current).into());
264 } else if let Some(obj) = self
265 .types
266 .get(field.ty.type_name())
267 .and_then(Type::as_input_object)
268 {
269 if ref_chain.insert(this_name) {
273 self.check_input_object_reference(current, obj, ref_chain)?;
274 ref_chain.remove(this_name);
275 }
276 }
277 }
278 }
279
280 Ok(())
281 }
282
283 fn check_interfaces(&self) -> Result<(), SchemaError> {
284 for ty in self.types.values() {
286 if let Type::Interface(interface) = ty {
287 for field in interface.fields.values() {
288 if field.name.starts_with("__") {
291 return Err(format!("Field \"{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", interface.name, field.name).into());
292 }
293
294 if let Some(ty) = self.types.get(field.ty.type_name()) {
296 if !ty.is_output_type() {
297 return Err(format!(
298 "Field \"{}.{}\" must return a output type",
299 interface.name, field.name
300 )
301 .into());
302 }
303 }
304
305 for arg in field.arguments.values() {
306 if arg.name.starts_with("__") {
309 return Err(format!("Argument \"{}.{}.{}\" must not have a name which begins with the characters \"__\" (two underscores)", interface.name, field.name, arg.name).into());
310 }
311
312 if let Some(ty) = self.types.get(arg.ty.type_name()) {
315 if !ty.is_input_type() {
316 return Err(format!(
317 "Argument \"{}.{}.{}\" must accept a input type",
318 interface.name, field.name, arg.name
319 )
320 .into());
321 }
322 }
323 }
324
325 if interface.implements.contains(&interface.name) {
328 return Err(format!(
329 "Interface \"{}\" may not implement itself",
330 interface.name
331 )
332 .into());
333 }
334
335 for interface_name in &interface.implements {
338 if let Some(ty) = self.types.get(interface_name) {
339 let implemenented_type = ty.as_interface().ok_or_else(|| {
340 format!("Type \"{}\" is not interface", interface_name)
341 })?;
342 check_is_valid_implementation(interface, implemenented_type)?;
343 }
344 }
345 }
346 }
347 }
348
349 Ok(())
350 }
351
352 fn check_unions(&self) -> Result<(), SchemaError> {
353 for ty in self.types.values() {
355 if let Type::Union(union) = ty {
356 for type_name in &union.possible_types {
361 if let Some(ty) = self.types.get(type_name) {
362 if ty.as_object().is_none() {
363 return Err(format!(
364 "Member \"{}\" of union \"{}\" is not an object",
365 type_name, union.name
366 )
367 .into());
368 }
369 }
370 }
371 }
372 }
373
374 Ok(())
375 }
376}
377
378fn check_is_valid_implementation(
379 implementing_type: &impl BaseContainer,
380 implemented_type: &Interface,
381) -> Result<(), SchemaError> {
382 for field in implemented_type.fields.values() {
383 let impl_field = implementing_type.field(&field.name).ok_or_else(|| {
384 format!(
385 "{} \"{}\" requires field \"{}\" defined by interface \"{}\"",
386 implementing_type.graphql_type(),
387 implementing_type.name(),
388 field.name,
389 implemented_type.name
390 )
391 })?;
392
393 for arg in field.arguments.values() {
394 let impl_arg = match impl_field.argument(&arg.name) {
395 Some(impl_arg) => impl_arg,
396 None if !arg.ty.is_nullable() => {
397 return Err(format!(
398 "Field \"{}.{}\" requires argument \"{}\" defined by interface \"{}.{}\"",
399 implementing_type.name(),
400 field.name,
401 arg.name,
402 implemented_type.name,
403 field.name,
404 )
405 .into());
406 }
407 None => continue,
408 };
409
410 if !arg.ty.is_subtype(&impl_arg.ty) {
411 return Err(format!(
412 "Argument \"{}.{}.{}\" is not sub-type of \"{}.{}.{}\"",
413 implemented_type.name,
414 field.name,
415 arg.name,
416 implementing_type.name(),
417 field.name,
418 arg.name
419 )
420 .into());
421 }
422 }
423
424 if !impl_field.ty().is_subtype(&field.ty) {
427 return Err(format!(
428 "Field \"{}.{}\" is not sub-type of \"{}.{}\"",
429 implementing_type.name(),
430 field.name,
431 implemented_type.name,
432 field.name,
433 )
434 .into());
435 }
436 }
437
438 Ok(())
439}
440
#[cfg(test)]
mod tests {
    use crate::{
        Value,
        dynamic::{
            Field, FieldFuture, InputObject, InputValue, Object, Schema, SchemaBuilder, TypeRef,
        },
    };

    /// Builds a schema whose `Query` root has a single `dummy: Int` field, so
    /// each test only has to register the input objects it exercises.
    fn base_schema() -> SchemaBuilder {
        let dummy = Field::new("dummy", TypeRef::named("Int"), |_| {
            FieldFuture::new(async { Ok(Some(Value::from(42))) })
        });
        Schema::build("Query", None, None).register(Object::new("Query").field(dummy))
    }

    /// Registers the three input objects, in order, on top of `base_schema`.
    fn schema_with(
        top_level: InputObject,
        mid_level: InputObject,
        bot_level: InputObject,
    ) -> SchemaBuilder {
        base_schema()
            .register(top_level)
            .register(mid_level)
            .register(bot_level)
    }

    /// `TopLevel -> MidLevel -> BotLevel -> TopLevel` is accepted: the only
    /// links out of `MidLevel` are a nullable field and a list field.
    #[test]
    fn test_recursive_input_objects() {
        let schema = schema_with(
            InputObject::new("TopLevel")
                .field(InputValue::new("mid", TypeRef::named_nn("MidLevel"))),
            InputObject::new("MidLevel")
                .field(InputValue::new("bottom", TypeRef::named("BotLevel")))
                .field(InputValue::new(
                    "list_bottom",
                    TypeRef::named_nn_list_nn("BotLevel"),
                )),
            InputObject::new("BotLevel")
                .field(InputValue::new("top", TypeRef::named_nn("TopLevel"))),
        );
        schema.finish().unwrap();
    }

    /// A cycle made entirely of non-null named fields is rejected.
    #[test]
    fn test_recursive_input_objects_bad() {
        let schema = schema_with(
            InputObject::new("TopLevel")
                .field(InputValue::new("mid", TypeRef::named_nn("MidLevel"))),
            InputObject::new("MidLevel")
                .field(InputValue::new("bottom", TypeRef::named_nn("BotLevel"))),
            InputObject::new("BotLevel")
                .field(InputValue::new("top", TypeRef::named_nn("TopLevel"))),
        );
        schema.finish().unwrap_err();
    }

    /// A non-null cycle between `MidLevel` and `BotLevel` is rejected even
    /// though it does not pass back through `TopLevel`.
    #[test]
    fn test_recursive_input_objects_local_cycle() {
        let schema = schema_with(
            InputObject::new("TopLevel")
                .field(InputValue::new("mid", TypeRef::named_nn("MidLevel"))),
            InputObject::new("MidLevel")
                .field(InputValue::new("bottom", TypeRef::named_nn("BotLevel"))),
            InputObject::new("BotLevel")
                .field(InputValue::new("mid", TypeRef::named_nn("MidLevel"))),
        );
        schema.finish().unwrap_err();
    }
}