1#![allow(dead_code)]
2use std::{
3 collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet, VecDeque},
4 ops::Add,
5 sync::{Arc, OnceLock},
6};
7
8use async_graphql_parser::{
9 parse_schema,
10 types::{
11 BaseType, DirectiveDefinition, FieldDefinition, ObjectType, SchemaDefinition,
12 ServiceDocument, TypeDefinition, TypeKind, TypeSystemDefinition,
13 },
14 Positioned,
15};
16
17pub use ::async_graphql_parser::Error;
18use async_graphql_value::Name;
19use itertools::Itertools;
20use serde::{Deserialize, Serialize};
21
22use crate::ir::Type;
23use crate::util::{BTreeMapTryInsertExt, HashMapTryInsertExt};
24
25use self::error::InvalidSchemaError;
26
27mod adapter;
28pub mod error;
29
30pub use adapter::SchemaAdapter;
31
/// A validated Trustfall schema, built from a parsed GraphQL schema document.
///
/// Instances are created via `Schema::parse` or `Schema::new`, both of which run the
/// schema validation checks defined in this module before returning a value.
#[derive(Debug, Clone)]
pub struct Schema {
    /// The document's `schema { ... }` definition.
    pub(crate) schema: SchemaDefinition,
    /// The object type named as the `query` root in the schema definition.
    pub(crate) query_type: ObjectType,
    /// Directive definitions declared in the document, keyed by directive name.
    pub(crate) directives: HashMap<Arc<str>, DirectiveDefinition>,
    /// Scalar type definitions declared in the document, keyed by type name.
    /// Built-in scalars are never present here: defining a type with a built-in name is rejected.
    pub(crate) scalars: HashMap<Arc<str>, TypeDefinition>,
    /// Object and interface type definitions ("vertex types"), keyed by type name.
    pub(crate) vertex_types: HashMap<Arc<str>, TypeDefinition>,
    /// Every field declared on an object or interface type, keyed by `(type name, field name)`.
    pub(crate) fields: HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
    /// For each `(type name, field name)` pair, the type(s) where that field originates,
    /// as determined by following `implements` relationships.
    pub(crate) field_origins: BTreeMap<(Arc<str>, Arc<str>), FieldOrigin>,
}
42
/// The type (or types) where a field on a vertex type originates.
///
/// Computed by following `implements` relationships from each type toward the types it
/// implements; see `get_field_origins`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) enum FieldOrigin {
    /// The field traces back to exactly one originating type, named here.
    SingleAncestor(Arc<str>),
    /// The field traces back to several distinct originating types. Schema validation
    /// reports any field with this origin as an ambiguous field origin.
    MultipleAncestors(BTreeSet<Arc<str>>),
}
48
49impl Add for &FieldOrigin {
50 type Output = FieldOrigin;
51
52 fn add(self, rhs: Self) -> Self::Output {
53 match (self, rhs) {
54 (FieldOrigin::SingleAncestor(l), FieldOrigin::SingleAncestor(r)) => {
55 if l == r {
56 self.clone()
57 } else {
58 FieldOrigin::MultipleAncestors(btreeset![l.clone(), r.clone()])
59 }
60 }
61 (FieldOrigin::SingleAncestor(single), FieldOrigin::MultipleAncestors(multi))
62 | (FieldOrigin::MultipleAncestors(multi), FieldOrigin::SingleAncestor(single)) => {
63 let mut new_set = multi.clone();
64 new_set.insert(single.clone());
65 FieldOrigin::MultipleAncestors(new_set)
66 }
67 (FieldOrigin::MultipleAncestors(l_set), FieldOrigin::MultipleAncestors(r_set)) => {
68 let mut new_set = l_set.clone();
69 new_set.extend(r_set.iter().cloned());
70 FieldOrigin::MultipleAncestors(new_set)
71 }
72 }
73 }
74}
75
/// Lazily-initialized storage for the set returned by [`get_builtin_scalars`].
static BUILTIN_SCALARS: OnceLock<HashSet<&'static str>> = OnceLock::new();

/// Returns the names of the built-in scalar types: `Int`, `Float`, `String`, `Boolean`, `ID`.
///
/// The set is built on first call and the same instance is returned on every later call.
pub(crate) fn get_builtin_scalars() -> &'static HashSet<&'static str> {
    BUILTIN_SCALARS.get_or_init(|| HashSet::from(["Int", "Float", "String", "Boolean", "ID"]))
}

/// Type and field names beginning with this prefix are reported as reserved-name errors
/// during schema validation.
const RESERVED_PREFIX: &str = "__";
91
92impl Schema {
93 pub const ALL_DIRECTIVE_DEFINITIONS: &'static str = "
94directive @filter(op: String!, value: [String!]) repeatable on FIELD | INLINE_FRAGMENT
95directive @tag(name: String) on FIELD
96directive @output(name: String) on FIELD
97directive @optional on FIELD
98directive @recurse(depth: Int!) on FIELD
99directive @fold on FIELD
100directive @transform(op: String!) on FIELD
101";
102
103 pub fn parse(input: impl AsRef<str>) -> Result<Self, InvalidSchemaError> {
104 let doc = parse_schema(input)?;
105 Self::new(doc)
106 }
107
108 pub fn new(doc: ServiceDocument) -> Result<Self, InvalidSchemaError> {
109 let mut schema: Option<SchemaDefinition> = None;
110 let mut directives: HashMap<Arc<str>, DirectiveDefinition> = Default::default();
111 let mut scalars: HashMap<Arc<str>, TypeDefinition> = Default::default();
112
113 let mut vertex_types: HashMap<Arc<str>, TypeDefinition> =
116 HashMap::with_capacity(doc.definitions.len() - 1);
117
118 let mut fields: HashMap<(Arc<str>, Arc<str>), FieldDefinition> =
120 HashMap::with_capacity(doc.definitions.len() - 1);
121
122 for definition in doc.definitions {
123 match definition {
124 TypeSystemDefinition::Schema(s) => {
125 assert!(schema.is_none());
126 if s.node.extend {
127 unimplemented!("Trustfall does not support extending schemas");
128 }
129
130 schema = Some(s.node);
131 }
132 TypeSystemDefinition::Directive(d) => {
133 directives
134 .insert_or_error(Arc::from(d.node.name.node.to_string()), d.node)
135 .unwrap();
136 }
137 TypeSystemDefinition::Type(t) => {
138 let node = t.node;
139 let type_name: Arc<str> = Arc::from(node.name.node.to_string());
140 assert!(!get_builtin_scalars().contains(type_name.as_ref()));
141
142 if node.extend {
143 unimplemented!("Trustfall does not support extending schemas");
144 }
145
146 match &node.kind {
147 TypeKind::Scalar => {
148 scalars.insert_or_error(type_name.clone(), node.clone()).unwrap();
149 }
150 TypeKind::Object(_) | TypeKind::Interface(_) => {
151 match vertex_types.insert_or_error(type_name.clone(), node.clone()) {
152 Ok(_) => {}
153 Err(err) => {
154 let type_or_interface_name = err.entry.key();
155 return Err(
156 InvalidSchemaError::DuplicateTypeOrInterfaceDefinition(
157 type_or_interface_name.to_string(),
158 ),
159 );
160 }
161 }
162 }
163 TypeKind::Enum(_) => unimplemented!("Trustfall does not support enum's"),
164 TypeKind::Union(_) => unimplemented!("Trustfall does not support unions's"),
165 TypeKind::InputObject(_) => {
166 unimplemented!("Trustfall does not support input objects's")
167 }
168 }
169
170 let field_defs = match node.kind {
171 TypeKind::Object(object) => Some(object.fields),
172
173 TypeKind::Interface(interface) => Some(interface.fields),
174 _ => None,
175 };
176 if let Some(field_defs) = field_defs {
177 for field in field_defs {
178 let field_node = field.node;
179 let field_name = Arc::from(field_node.name.node.to_string());
180
181 match fields
182 .insert_or_error((type_name.clone(), field_name), field_node)
183 {
184 Ok(_) => {}
185 Err(err) => {
186 let (key, value) = err.entry.key();
187 return Err(InvalidSchemaError::DuplicateFieldDefinition(
188 key.to_string(),
189 value.to_string(),
190 ));
191 }
192 }
193 }
194 }
195 }
196 }
197 }
198
199 let schema = schema.expect("Schema definition was not present.");
200 let query_type_name =
201 schema.query.as_ref().expect("No query type was declared in the schema").node.as_ref();
202 let query_type_definition = vertex_types
203 .get(query_type_name)
204 .expect("The query type set in the schema object was never defined.");
205 let query_type = match &query_type_definition.kind {
206 TypeKind::Object(o) => o.clone(),
207 _ => unreachable!(),
208 };
209
210 let mut errors = vec![];
211 if let Err(e) = check_required_transitive_implementations(&vertex_types) {
212 errors.extend(e);
213 }
214 if let Err(e) = check_field_type_narrowing(&vertex_types, &fields) {
215 errors.extend(e);
216 }
217 if let Err(e) = check_fields_required_by_interface_implementations(&vertex_types, &fields) {
218 errors.extend(e);
219 }
220 if let Err(e) =
221 check_type_and_property_and_edge_invariants(query_type_definition, &vertex_types)
222 {
223 errors.extend(e);
224 }
225 if let Err(e) = check_root_query_type_invariants(query_type_definition, &query_type) {
226 errors.extend(e);
227 }
228
229 let field_origins = match get_field_origins(&vertex_types) {
230 Ok(field_origins) => {
231 if let Err(e) = check_ambiguous_field_origins(&fields, &field_origins) {
232 errors.extend(e);
233 }
234 Some(field_origins)
235 }
236 Err(e) => {
237 errors.push(e);
238 None
239 }
240 };
241
242 if errors.is_empty() {
243 Ok(Self {
244 schema,
245 query_type,
246 directives,
247 scalars,
248 vertex_types,
249 fields,
250 field_origins: field_origins.expect("no field origins but also no errors"),
251 })
252 } else {
253 Err(errors.into())
254 }
255 }
256
257 pub fn subtypes<'a, 'slf: 'a>(
260 &'slf self,
261 type_name: &'a str,
262 ) -> Option<impl Iterator<Item = &'slf str> + 'a> {
263 if !self.vertex_types.contains_key(type_name) {
264 return None;
265 }
266
267 Some(self.vertex_types.iter().sorted_by_key(|(name, _)| *name).filter_map(
268 move |(name, defn)| {
269 if name.as_ref() == type_name
270 || get_vertex_type_implements(defn).iter().any(|x| x.node.as_ref() == type_name)
271 {
272 Some(name.as_ref())
273 } else {
274 None
275 }
276 },
277 ))
278 }
279
280 pub(crate) fn query_type_name(&self) -> &str {
281 self.schema.query.as_ref().unwrap().node.as_ref()
282 }
283
284 pub(crate) fn vertex_type_implements(&self, vertex_type: &str) -> &[Positioned<Name>] {
285 get_vertex_type_implements(&self.vertex_types[vertex_type])
286 }
287
288 pub(crate) fn is_subtype(
289 &self,
290 parent_type: &async_graphql_parser::types::Type,
291 maybe_subtype: &async_graphql_parser::types::Type,
292 ) -> bool {
293 is_subtype(&self.vertex_types, parent_type, maybe_subtype)
294 }
295
296 pub(crate) fn is_named_type_subtype(&self, parent_type: &str, maybe_subtype: &str) -> bool {
297 is_named_type_subtype(&self.vertex_types, parent_type, maybe_subtype)
298 }
299}
300
301fn check_root_query_type_invariants(
302 query_type_definition: &TypeDefinition,
303 query_type: &ObjectType,
304) -> Result<(), Vec<InvalidSchemaError>> {
305 let mut errors: Vec<InvalidSchemaError> = vec![];
306
307 for field_defn in &query_type.fields {
308 let field_type = Type::from_type(&field_defn.node.ty.node);
309 let base_named_type = field_type.base_type();
310 if get_builtin_scalars().contains(base_named_type) {
311 errors.push(InvalidSchemaError::PropertyFieldOnRootQueryType(
312 query_type_definition.name.node.to_string(),
313 field_defn.node.name.node.to_string(),
314 field_type.to_string(),
315 ));
316 }
317
318 }
325
326 if errors.is_empty() {
327 Ok(())
328 } else {
329 Err(errors)
330 }
331}
332
333fn check_type_and_property_and_edge_invariants(
334 query_type_definition: &TypeDefinition,
335 vertex_types: &HashMap<Arc<str>, TypeDefinition>,
336) -> Result<(), Vec<InvalidSchemaError>> {
337 let mut errors: Vec<InvalidSchemaError> = vec![];
338
339 for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
340 if type_name.as_ref().starts_with(RESERVED_PREFIX) {
341 errors.push(InvalidSchemaError::ReservedTypeName(type_name.to_string()));
342 }
343
344 let type_fields = get_vertex_type_fields(type_defn);
345
346 for defn in type_fields {
347 let field_defn = &defn.node;
348 let field_type = &field_defn.ty.node;
349
350 if field_defn.name.node.as_ref().starts_with(RESERVED_PREFIX) {
351 errors.push(InvalidSchemaError::ReservedFieldName(
352 type_name.to_string(),
353 field_defn.name.node.to_string(),
354 ));
355 }
356
357 let field_type = Type::from_type(field_type);
358
359 let base_named_type = field_type.base_type();
360 if get_builtin_scalars().contains(base_named_type) {
361 if !field_defn.arguments.is_empty() {
363 errors.push(InvalidSchemaError::PropertyFieldWithParameters(
364 type_name.to_string(),
365 field_defn.name.node.to_string(),
366 field_type.to_string(),
367 field_defn.arguments.iter().map(|x| x.node.name.node.to_string()).collect(),
368 ));
369 }
370 } else if vertex_types.contains_key(base_named_type) {
371 if base_named_type == query_type_definition.name.node.as_ref() {
373 errors.push(InvalidSchemaError::EdgePointsToRootQueryType(
375 type_name.to_string(),
376 field_defn.name.node.to_string(),
377 field_type.to_string(),
378 ));
379 } else {
380 for param_defn in &field_defn.arguments {
382 if let Some(value) = ¶m_defn.node.default_value {
383 let param_type = ¶m_defn.node.ty.node;
384 match value.node.clone().try_into() {
385 Ok(value) => {
386 if !Type::from_type(param_type).is_valid_value(&value) {
387 errors.push(InvalidSchemaError::InvalidDefaultValueForFieldParameter(
388 type_name.to_string(),
389 field_defn.name.node.to_string(),
390 param_defn.node.name.node.to_string(),
391 param_type.to_string(),
392 format!("{value:?}"),
393 ));
394 }
395 }
396 Err(_) => {
397 errors.push(
398 InvalidSchemaError::InvalidDefaultValueForFieldParameter(
399 type_name.to_string(),
400 field_defn.name.node.to_string(),
401 param_defn.node.name.node.to_string(),
402 param_type.to_string(),
403 value.node.to_string(),
404 ),
405 );
406 }
407 }
408 }
409 }
410
411 if let Some(inner_list) = field_type.as_list() {
414 if inner_list.is_list() {
415 errors.push(InvalidSchemaError::InvalidEdgeType(
416 type_name.to_string(),
417 field_defn.name.node.to_string(),
418 field_type.to_string(),
419 ));
420 }
421 }
422 }
423 } else {
424 errors.push(InvalidSchemaError::UnknownPropertyOrEdgeType(
425 field_defn.name.node.as_ref().to_string(),
426 field_type.to_string(),
427 ))
428 }
429 }
430 }
431
432 if errors.is_empty() {
433 Ok(())
434 } else {
435 Err(errors)
436 }
437}
438
439fn is_named_type_subtype(
440 vertex_types: &HashMap<Arc<str>, TypeDefinition>,
441 parent_type: &str,
442 maybe_subtype: &str,
443) -> bool {
444 let parent_is_vertex = vertex_types.contains_key(parent_type);
445 let maybe_sub = vertex_types.get(maybe_subtype);
446
447 match (parent_is_vertex, maybe_sub) {
448 (false, None) => {
449 parent_type == maybe_subtype
452 }
453 (true, Some(maybe_subtype_vertex)) => {
454 parent_type == maybe_subtype
458 || get_vertex_type_implements(maybe_subtype_vertex)
459 .iter()
460 .any(|pos| pos.node.as_ref() == parent_type)
461 }
462 _ => {
463 false
466 }
467 }
468}
469
470fn is_subtype(
471 vertex_types: &HashMap<Arc<str>, TypeDefinition>,
472 parent_type: &async_graphql_parser::types::Type,
473 maybe_subtype: &async_graphql_parser::types::Type,
474) -> bool {
475 if !parent_type.nullable && maybe_subtype.nullable {
478 return false;
479 }
480
481 match (&parent_type.base, &maybe_subtype.base) {
482 (BaseType::Named(parent), BaseType::Named(subtype)) => {
483 is_named_type_subtype(vertex_types, parent.as_ref(), subtype.as_ref())
484 }
485 (BaseType::List(parent_type), BaseType::List(maybe_subtype)) => {
486 is_subtype(vertex_types, parent_type, maybe_subtype)
487 }
488 (BaseType::Named(..), BaseType::List(..)) | (BaseType::List(..), BaseType::Named(..)) => {
489 false
490 }
491 }
492}
493
494fn check_ambiguous_field_origins(
495 fields: &HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
496 field_origins: &BTreeMap<(Arc<str>, Arc<str>), FieldOrigin>,
497) -> Result<(), Vec<InvalidSchemaError>> {
498 let mut errors = vec![];
499
500 for (key, origin) in field_origins {
501 let (type_name, field_name) = key;
502 if let FieldOrigin::MultipleAncestors(ancestors) = &origin {
503 let field_type = fields[key].ty.node.to_string();
504 errors.push(InvalidSchemaError::AmbiguousFieldOrigin(
505 type_name.to_string(),
506 field_name.to_string(),
507 field_type,
508 ancestors.iter().map(|x| x.to_string()).collect(),
509 ))
510 }
511 }
512
513 if errors.is_empty() {
514 Ok(())
515 } else {
516 Err(errors)
517 }
518}
519
520fn check_required_transitive_implementations(
528 vertex_types: &HashMap<Arc<str>, TypeDefinition>,
529) -> Result<(), Vec<InvalidSchemaError>> {
530 let mut errors: Vec<InvalidSchemaError> = vec![];
531
532 for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
533 let implementations: BTreeSet<&str> =
534 get_vertex_type_implements(type_defn).iter().map(|x| x.node.as_ref()).collect();
535
536 for implements_type in implementations.iter().copied() {
538 match vertex_types.get(implements_type) {
539 Some(implementation_defn) => {
540 if !matches!(implementation_defn.kind, TypeKind::Interface(..)) {
541 errors.push(InvalidSchemaError::ImplementingNonInterface(
542 type_name.to_string(),
543 implements_type.to_string(),
544 ));
545 } else {
546 for expected_impl in get_vertex_type_implements(implementation_defn) {
547 let expected_impl_name = expected_impl.node.as_ref();
548
549 if expected_impl_name != type_name.as_ref()
553 && !implementations.contains(expected_impl_name)
554 {
555 errors.push(
556 InvalidSchemaError::MissingTransitiveInterfaceImplementation(
557 type_name.to_string(),
558 implements_type.to_string(),
559 expected_impl_name.to_string(),
560 ),
561 );
562 }
563 }
564 }
565 }
566 None => {
567 errors.push(InvalidSchemaError::ImplementingNonExistentType(
568 type_name.to_string(),
569 implements_type.to_string(),
570 ));
571 }
572 }
573 }
574 }
575
576 if errors.is_empty() {
577 Ok(())
578 } else {
579 Err(errors)
580 }
581}
582
583fn check_fields_required_by_interface_implementations(
584 vertex_types: &HashMap<Arc<str>, TypeDefinition>,
585 fields: &HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
586) -> Result<(), Vec<InvalidSchemaError>> {
587 let mut errors: Vec<InvalidSchemaError> = vec![];
588
589 for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
590 let implementations = get_vertex_type_implements(type_defn);
591
592 for implementation in implementations {
593 let implementation = implementation.node.as_ref();
594 let Some(impl_defn) = vertex_types.get(implementation) else {
595 continue;
596 };
597
598 for field in get_vertex_type_fields(impl_defn) {
599 let field_name = field.node.name.node.as_ref();
600
601 if !fields.contains_key(&(type_name.clone(), Arc::from(field_name))) {
604 errors.push(InvalidSchemaError::MissingRequiredField(
605 type_name.to_string(),
606 implementation.to_string(),
607 field_name.to_string(),
608 field.node.ty.node.to_string(),
609 ))
610 }
611 }
612 }
613 }
614
615 if errors.is_empty() {
616 Ok(())
617 } else {
618 Err(errors)
619 }
620}
621
/// Checks that fields shared with implemented types are compatible with those types' fields.
///
/// Vertex types are visited in name order, fields in declaration order, and implemented types
/// in declaration order. Implemented types that do not declare a same-named field are skipped.
/// For each matching pair, the following are reported:
/// - `InvalidTypeWideningOfInheritedField` if the field's type is not a subtype (per
///   `is_subtype`) of the implemented field's type;
/// - `InheritedFieldMissingParameters` for parameters the implemented field has but this one lacks;
/// - `InheritedFieldUnexpectedParameters` for parameters this field has but the implemented one
///   lacks;
/// - `InvalidTypeNarrowingOfInheritedFieldParameter` for each shared parameter whose type fails
///   the `Type::is_scalar_only_subtype` check against the implemented parameter's type.
///
/// Returns all such errors, or `Ok(())` if there are none.
fn check_field_type_narrowing(
    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
    fields: &HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
) -> Result<(), Vec<InvalidSchemaError>> {
    let mut errors: Vec<InvalidSchemaError> = vec![];

    for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
        let implementations = get_vertex_type_implements(type_defn);
        let type_fields = get_vertex_type_fields(type_defn);

        for field in type_fields {
            let field_name = field.node.name.node.as_ref();
            let field_type = &field.node.ty.node;
            // Parameter name -> parameter type, for this type's own definition of the field.
            let field_parameters: BTreeMap<_, _> = field
                .node
                .arguments
                .iter()
                .map(|arg| (arg.node.name.node.as_ref(), &arg.node.ty.node))
                .collect();

            for implementation in implementations {
                let implementation = implementation.node.as_ref();

                // Only implemented types that declare a field with the same name are compared.
                if let Some(parent_field) =
                    fields.get(&(Arc::from(implementation), Arc::from(field_name)))
                {
                    let parent_field_type = &parent_field.ty.node;
                    if !is_subtype(vertex_types, parent_field_type, field_type) {
                        errors.push(InvalidSchemaError::InvalidTypeWideningOfInheritedField(
                            field_name.to_string(),
                            type_name.to_string(),
                            implementation.to_string(),
                            field_type.to_string(),
                            parent_field_type.to_string(),
                        ));
                    }

                    // Parameter name -> parameter type, for the implemented type's field.
                    let parent_field_parameters: BTreeMap<_, _> = parent_field
                        .arguments
                        .iter()
                        .map(|arg| (arg.node.name.node.as_ref(), &arg.node.ty.node))
                        .collect();

                    let missing_parameters = parent_field_parameters
                        .keys()
                        .copied()
                        .filter(|name| !field_parameters.contains_key(*name))
                        .collect_vec();
                    if !missing_parameters.is_empty() {
                        errors.push(InvalidSchemaError::InheritedFieldMissingParameters(
                            field_name.to_owned(),
                            type_name.to_string(),
                            implementation.to_owned(),
                            missing_parameters.into_iter().map(ToOwned::to_owned).collect_vec(),
                        ));
                    }

                    let unexpected_parameters = field_parameters
                        .keys()
                        .copied()
                        .filter(|name| !parent_field_parameters.contains_key(*name))
                        .collect_vec();
                    if !unexpected_parameters.is_empty() {
                        errors.push(InvalidSchemaError::InheritedFieldUnexpectedParameters(
                            field_name.to_owned(),
                            type_name.to_string(),
                            implementation.to_owned(),
                            unexpected_parameters.into_iter().map(ToOwned::to_owned).collect_vec(),
                        ));
                    }

                    // Note: `field_type` here shadows the field's own type above and refers to
                    // the type of one of its parameters.
                    for (&field_parameter, &field_type) in &field_parameters {
                        if let Some(&parent_field_type) =
                            parent_field_parameters.get(field_parameter)
                        {
                            if !Type::from_type(field_type)
                                .is_scalar_only_subtype(&Type::from_type(parent_field_type))
                            {
                                errors.push(InvalidSchemaError::InvalidTypeNarrowingOfInheritedFieldParameter(
                                    field_name.to_owned(),
                                    type_name.to_string(),
                                    implementation.to_owned(),
                                    field_parameter.to_string(),
                                    field_type.to_string(),
                                    parent_field_type.to_string(),
                                ));
                            }
                        }
                    }
                }
            }
        }
    }

    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors)
    }
}
731
732fn get_vertex_type_fields(vertex: &TypeDefinition) -> &[Positioned<FieldDefinition>] {
733 match &vertex.kind {
734 TypeKind::Object(obj) => &obj.fields,
735 TypeKind::Interface(iface) => &iface.fields,
736 _ => unreachable!(),
737 }
738}
739
740fn get_vertex_type_implements(vertex: &TypeDefinition) -> &[Positioned<Name>] {
741 match &vertex.kind {
742 TypeKind::Object(obj) => &obj.implements,
743 TypeKind::Interface(iface) => &iface.implements,
744 _ => unreachable!(),
745 }
746}
747
/// Determines, for every field declared on every vertex type, where that field originates.
///
/// A declared field that is not declared on any of the type's implemented vertex types
/// originates at the type itself (`FieldOrigin::SingleAncestor(type_name)`). A declared field
/// that implemented types also declare takes their origins, combined with `+` when several
/// implemented types declare it.
///
/// Types are processed in dependency order, using a Kahn's-algorithm-style topological sort.
/// A type is processed only after every vertex type it implements has been processed, so the
/// origins of the parents' fields are always already known. `implements` entries that name
/// non-vertex types are ignored here.
///
/// # Errors
/// Returns `CircularImplementsRelationships` if some types can never be processed because they
/// depend on a cycle of `implements` relationships. The reported names are, for the first such
/// type in name order, that type plus its implemented types that were never processed.
#[allow(clippy::type_complexity)]
fn get_field_origins(
    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
) -> Result<BTreeMap<(Arc<str>, Arc<str>), FieldOrigin>, InvalidSchemaError> {
    let mut field_origins: BTreeMap<(Arc<str>, Arc<str>), FieldOrigin> = Default::default();
    // Types ready to be processed: all of their implemented vertex types are already processed.
    let mut queue = VecDeque::new();

    // For each type: the implemented vertex types not yet processed. Types with nothing to
    // wait on are queued immediately, in name order.
    let mut required_resolutions: BTreeMap<&str, BTreeSet<&str>> = vertex_types
        .iter()
        .sorted_by_key(|(name, _)| *name)
        .map(|(name, defn)| {
            let resolutions: BTreeSet<&str> = get_vertex_type_implements(defn)
                .iter()
                .map(|x| x.node.as_ref())
                .filter(|name| vertex_types.contains_key(*name))
                .collect();
            if resolutions.is_empty() {
                queue.push_back(name);
            }
            (name.as_ref(), resolutions)
        })
        .collect();

    // Reverse edges: for each implemented type name, the types that list it in `implements`.
    let resolvers: BTreeMap<&str, BTreeSet<Arc<str>>> = vertex_types
        .iter()
        .sorted_by_key(|(name, _)| *name)
        .flat_map(|(name, defn)| {
            get_vertex_type_implements(defn)
                .iter()
                .map(|x| (x.node.as_ref(), Arc::from(name.as_ref())))
        })
        .fold(Default::default(), |mut acc, (interface, implementer)| {
            match acc.entry(interface) {
                Entry::Vacant(v) => {
                    v.insert(btreeset![implementer]);
                }
                Entry::Occupied(occ) => {
                    occ.into_mut().insert(implementer);
                }
            }
            acc
        });

    while let Some(type_name) = queue.pop_front() {
        let defn = &vertex_types[type_name];
        let implements = get_vertex_type_implements(defn);
        let fields = get_vertex_type_fields(defn);

        // Origins of every field declared on any implemented vertex type, merged by field name.
        let mut implemented_fields: BTreeMap<&str, FieldOrigin> = Default::default();
        for implemented_interface in implements {
            let implemented_interface = implemented_interface.node.as_ref();
            let Some(implemented_defn) = vertex_types.get(implemented_interface) else {
                continue;
            };
            let parent_fields = get_vertex_type_fields(implemented_defn);
            for field in parent_fields {
                // The implemented type was processed earlier, so all of its declared fields
                // already have recorded origins.
                let parent_field_origin = &field_origins
                    [&(Arc::from(implemented_interface), Arc::from(field.node.name.node.as_ref()))];

                implemented_fields
                    .entry(field.node.name.node.as_ref())
                    .and_modify(|origin| *origin = (origin as &FieldOrigin) + parent_field_origin)
                    .or_insert_with(|| parent_field_origin.clone());
            }
        }

        // Record origins only for fields this type declares itself. Fields inherited from
        // implemented types but not declared here are left in `implemented_fields` and dropped.
        for field in fields {
            let field = &field.node;
            let field_name = &field.name.node;

            let origin = implemented_fields
                .remove(field_name.as_ref())
                .unwrap_or_else(|| FieldOrigin::SingleAncestor(type_name.clone()));
            field_origins
                .insert_or_error((type_name.clone(), Arc::from(field_name.as_ref())), origin)
                .unwrap();
        }

        // This type is done: unblock its implementers, queueing any with nothing left to wait on.
        if let Some(next_types) = resolvers.get(type_name.as_ref()) {
            for next_type in next_types.iter() {
                let remaining = required_resolutions.get_mut(next_type.as_ref()).unwrap();
                if remaining.remove(type_name.as_ref()) && remaining.is_empty() {
                    queue.push_back(next_type);
                }
            }
        }
    }

    // Any type still waiting on dependencies can never be processed: it depends on a cycle.
    for (required, mut remaining) in required_resolutions.into_iter() {
        if !remaining.is_empty() {
            remaining.insert(required);
            let circular_implementations =
                remaining.into_iter().map(|x| x.to_string()).collect_vec();
            return Err(InvalidSchemaError::CircularImplementsRelationships(
                circular_implementations,
            ));
        }
    }

    Ok(field_origins)
}
851
#[cfg(test)]
mod tests {
    use std::{fs, path::Path};

    use async_graphql_parser::parse_schema;
    use itertools::Itertools;
    use trustfall_filetests_macros::parameterize;

    use super::{error::InvalidSchemaError, Schema};

    /// Each `*.graphql` fixture must fail validation with exactly the error stored in the
    /// sibling `*.schema-error.ron` file.
    #[parameterize("trustfall_core/test_data/tests/schema_errors", "*.graphql")]
    fn schema_errors(base: &Path, stem: &str) {
        let input_data = fs::read_to_string(base.join(format!("{stem}.graphql"))).unwrap();
        let error_data =
            fs::read_to_string(base.join(format!("{stem}.schema-error.ron"))).unwrap();
        let expected_error: InvalidSchemaError = ron::from_str(&error_data).unwrap();

        let schema_doc = parse_schema(input_data).unwrap();

        match Schema::new(schema_doc) {
            Ok(_) => panic!("Expected an error but got valid schema."),
            Err(e) => assert_eq!(e, expected_error),
        }
    }

    /// Each `*.graphql` fixture must include all Trustfall directives and parse successfully.
    #[parameterize("trustfall_core/test_data/tests/valid_schemas", "*.graphql")]
    fn valid_schemas(base: &Path, stem: &str) {
        let input_data = fs::read_to_string(base.join(format!("{stem}.graphql"))).unwrap();

        assert!(input_data.contains(Schema::ALL_DIRECTIVE_DEFINITIONS));

        if let Err(e) = Schema::parse(input_data) {
            panic!("{}", e);
        }
    }

    #[test]
    fn schema_subtypes() {
        let schema = Schema::parse(include_str!("../../test_data/schemas/numbers.graphql"))
            .expect("valid schema");

        assert!(schema.subtypes("Nonexistent").is_none());

        let composite_subtypes = schema.subtypes("Composite").unwrap().collect_vec();
        assert_eq!(vec!["Composite"], composite_subtypes);

        let mut number_subtypes = schema.subtypes("Number").unwrap().collect_vec();
        number_subtypes.sort_unstable();
        assert_eq!(vec!["Composite", "Neither", "Number", "Prime"], number_subtypes);
    }
}